Revert raylet to worker GRPC communication back to asio (#5450)

This commit is contained in:
Philipp Moritz
2019-08-17 19:11:32 -07:00
committed by GitHub
parent 03d05c8765
commit 599cc2be60
45 changed files with 1418 additions and 1764 deletions
+12 -43
View File
@@ -65,20 +65,6 @@ cc_proto_library(
deps = [":object_manager_proto"],
)
proto_library(
name = "raylet_proto",
srcs = ["src/ray/protobuf/raylet.proto"],
deps = [
":common_proto",
":gcs_proto",
],
)
cc_proto_library(
name = "raylet_cc_proto",
deps = [":raylet_proto"],
)
proto_library(
name = "worker_proto",
srcs = ["src/ray/protobuf/worker.proto"],
@@ -182,34 +168,6 @@ cc_library(
],
)
# Raylet gRPC lib.
cc_grpc_library(
name = "raylet_cc_grpc",
srcs = [":raylet_proto"],
grpc_only = True,
deps = [":raylet_cc_proto"],
)
# Raylet rpc server and client.
cc_library(
name = "raylet_rpc",
srcs = glob([
"src/ray/rpc/raylet/*.cc",
]),
hdrs = glob([
"src/ray/rpc/raylet/*.h",
"src/ray/raylet/*.h",
]),
copts = COPTS,
deps = [
":grpc_common_lib",
":ray_common",
":raylet_cc_grpc",
"@boost//:asio",
"@com_github_grpc_grpc//:grpc++",
],
)
# Worker gRPC lib.
cc_grpc_library(
name = "worker_cc_grpc",
@@ -263,6 +221,7 @@ cc_library(
copts = COPTS,
deps = [
":common_cc_proto",
":node_manager_fbs",
":ray_util",
"@boost//:asio",
"@com_github_grpc_grpc//:grpc++",
@@ -361,11 +320,11 @@ cc_library(
deps = [
":common_cc_proto",
":gcs",
":node_manager_fbs",
":node_manager_rpc",
":object_manager",
":ray_common",
":ray_util",
":raylet_rpc",
":stats_lib",
":worker_rpc",
"@boost//:asio",
@@ -461,6 +420,7 @@ cc_test(
srcs = ["src/ray/raylet/lineage_cache_test.cc"],
copts = COPTS,
deps = [
":node_manager_fbs",
":raylet_lib",
"@com_google_googletest//:gtest_main",
],
@@ -471,6 +431,7 @@ cc_test(
srcs = ["src/ray/raylet/reconstruction_policy_test.cc"],
copts = COPTS,
deps = [
":node_manager_fbs",
":object_manager",
":raylet_lib",
"@com_google_googletest//:gtest_main",
@@ -670,6 +631,7 @@ cc_library(
deps = [
":gcs_cc_proto",
":hiredis",
":node_manager_fbs",
":node_manager_rpc",
":ray_common",
":ray_util",
@@ -725,6 +687,13 @@ flatbuffer_cc_library(
out_prefix = "src/ray/common/",
)
flatbuffer_cc_library(
name = "node_manager_fbs",
srcs = ["src/ray/raylet/format/node_manager.fbs"],
flatc_args = FLATC_ARGS,
out_prefix = "src/ray/raylet/format/",
)
flatbuffer_cc_library(
name = "object_manager_fbs",
srcs = ["src/ray/object_manager/format/object_manager.fbs"],
@@ -24,7 +24,11 @@ public class BaseTest {
// These files need to be deleted after each test case.
filesToDelete = ImmutableList.of(
new File(Ray.getRuntimeContext().getRayletSocketName()),
new File(Ray.getRuntimeContext().getObjectStoreSocketName())
new File(Ray.getRuntimeContext().getObjectStoreSocketName()),
// TODO(pcm): This is a workaround for the issue described
// in the PR description of https://github.com/ray-project/ray/pull/5450
// and should be fixed properly.
new File("/tmp/ray/test/raylet_socket")
);
// Make sure the files will be deleted even if the test doesn't exit gracefully.
filesToDelete.forEach(File::deleteOnExit);
@@ -102,4 +102,3 @@ public class ResourcesManagementTest extends BaseTest {
}
}
+1 -1
View File
@@ -374,7 +374,7 @@ cdef class RayletClient:
@property
def client_id(self):
return ClientID(self.client.get().GetWorkerId().Binary())
return ClientID(self.client.get().GetWorkerID().Binary())
@property
def job_id(self):
+5 -5
View File
@@ -88,16 +88,16 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
cdef extern from "ray/protobuf/common.pb.h" nogil:
cdef cppclass CLanguage "ray::rpc::Language":
cdef cppclass CLanguage "Language":
pass
# This is a workaround for C++ enum class since Cython has no corresponding
# representation.
cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc::Language" nogil:
cdef CLanguage LANGUAGE_PYTHON "ray::rpc::Language::PYTHON"
cdef CLanguage LANGUAGE_CPP "ray::rpc::Language::CPP"
cdef CLanguage LANGUAGE_JAVA "ray::rpc::Language::JAVA"
cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil:
cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON"
cdef CLanguage LANGUAGE_CPP "Language::CPP"
cdef CLanguage LANGUAGE_JAVA "Language::JAVA"
cdef extern from "ray/common/task/scheduling_resources.h" \
+6 -6
View File
@@ -23,14 +23,14 @@ from ray.includes.task cimport CTaskSpec
cdef extern from "ray/protobuf/gcs.pb.h" nogil:
cdef cppclass GCSProfileEvent "ray::rpc::ProfileTableData::ProfileEvent":
cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent":
void set_event_type(const c_string &value)
void set_start_time(double value)
void set_end_time(double value)
c_string set_extra_data(const c_string &value)
GCSProfileEvent()
cdef cppclass GCSProfileTableData "ray::rpc::ProfileTableData":
cdef cppclass GCSProfileTableData "ProfileTableData":
void set_component_type(const c_string &value)
void set_component_id(const c_string &value)
void set_node_ip_address(const c_string &value)
@@ -43,12 +43,13 @@ ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \
ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair
cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil:
cdef cppclass CRayletClient "ray::rpc::RayletClient":
cdef extern from "ray/raylet/raylet_client.h" nogil:
cdef cppclass CRayletClient "RayletClient":
CRayletClient(const c_string &raylet_socket,
const CWorkerID &worker_id,
c_bool is_worker, const CJobID &job_id,
const CLanguage &language)
CRayStatus Disconnect()
CRayStatus SubmitTask(const CTaskSpec &task_spec)
CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec)
CRayStatus TaskDone()
@@ -72,8 +73,7 @@ cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil:
const CActorID &actor_id, const CActorCheckpointID &checkpoint_id)
CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id)
CLanguage GetLanguage() const
CWorkerID GetWorkerId() const
CWorkerID GetWorkerID() const
CJobID GetJobID() const
c_bool IsWorker() const
CRayStatus Disconnect()
const ResourceMappingType &GetResourceIDs() const
-11
View File
@@ -82,17 +82,6 @@ inline std::vector<T> VectorFromProtobuf(
return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
}
template <typename Message>
using AddFunction = void (Message::*)(const ::std::string &value);
/// Add a vector of type ID to protobuf message.
template <typename ID, typename Message>
inline void IdVectorToProtobuf(const std::vector<ID> &ids, Message &message,
AddFunction<Message> add_func) {
for (const auto &id : ids) {
(message.*add_func)(id.Binary());
}
}
/// Converts a Protobuf `RepeatedField` to a vector of IDs.
template <class ID>
inline std::vector<ID> IdVectorFromProtobuf(
+1 -8
View File
@@ -20,12 +20,8 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000)
/// warning is logged that the handler is taking too long.
RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100)
/// The duration between heartbeats. This value is used for both worker and raylet.
/// The duration between heartbeats. These are sent by the raylet.
RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100)
/// Worker heartbeats also use `heartbeat_timeout_milliseconds` as timer timeout period.
/// If a worker has not sent a heartbeat in the last `num_worker_heartbeats_timeout`
/// heartbeat intervals, raylet will mark this worker as dead.
RAY_CONFIG(int64_t, num_worker_heartbeats_timeout, 30)
/// If a component has not sent a heartbeat in the last num_heartbeats_timeout
/// heartbeat intervals, the raylet monitor process will report
/// it as dead to the db_client table.
@@ -156,9 +152,6 @@ RAY_CONFIG(int32_t, num_actor_checkpoints_to_keep, 20)
/// Maximum number of ids in one batch to send to GCS to delete keys.
RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000)
/// Number of times for a raylet client to retry to register.
RAY_CONFIG(int, num_raylet_client_retry_times, 25)
/// When getting objects from object store, print a warning every this number of attempts.
RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50)
+25 -10
View File
@@ -674,22 +674,37 @@ std::string ResourceIdSet::ToString() const {
return return_string;
}
std::vector<rpc::ResourceIdSetInfo> ResourceIdSet::ToProtobuf() const {
std::vector<rpc::ResourceIdSetInfo> resources;
std::vector<flatbuffers::Offset<protocol::ResourceIdSetInfo>> ResourceIdSet::ToFlatbuf(
flatbuffers::FlatBufferBuilder &fbb) const {
std::vector<flatbuffers::Offset<protocol::ResourceIdSetInfo>> return_message;
for (auto const &resource_pair : available_resources_) {
rpc::ResourceIdSetInfo resource_id_set_info;
resource_id_set_info.set_resource_name(resource_pair.first);
std::vector<int64_t> resource_ids;
std::vector<double> resource_fractions;
for (auto whole_id : resource_pair.second.WholeIds()) {
resource_id_set_info.add_resource_ids(whole_id);
resource_id_set_info.add_resource_fractions(1);
resource_ids.push_back(whole_id);
resource_fractions.push_back(1);
}
for (auto const &fractional_pair : resource_pair.second.FractionalIds()) {
resource_id_set_info.add_resource_ids(fractional_pair.first);
resource_id_set_info.add_resource_fractions(fractional_pair.second.ToDouble());
resource_ids.push_back(fractional_pair.first);
resource_fractions.push_back(fractional_pair.second.ToDouble());
}
resources.emplace_back(resource_id_set_info);
auto resource_id_set_message = protocol::CreateResourceIdSetInfo(
fbb, fbb.CreateString(resource_pair.first), fbb.CreateVector(resource_ids),
fbb.CreateVector(resource_fractions));
return_message.push_back(resource_id_set_message);
}
return resources;
return return_message;
}
const std::string ResourceIdSet::Serialize() const {
flatbuffers::FlatBufferBuilder fbb;
auto resource_id_set_flatbuf = ToFlatbuf(fbb);
fbb.Finish(fbb.CreateVector(resource_id_set_flatbuf));
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
}
/// SchedulingResources class implementation
+12 -4
View File
@@ -6,7 +6,7 @@
#include <unordered_set>
#include <vector>
#include "ray/protobuf/common.pb.h"
#include "ray/raylet/format/node_manager_generated.h"
namespace ray {
@@ -422,10 +422,18 @@ class ResourceIdSet {
/// \return A human-readable string version of the object.
std::string ToString() const;
/// \brief Convert this object to a vector of protobuf `ResourceIdSetInfo`s.
/// \brief Serialize this object using flatbuffers.
///
/// \return A vector inclusing resource id set infos.
std::vector<rpc::ResourceIdSetInfo> ToProtobuf() const;
/// \param fbb A flatbuffer builder object.
/// \return A flatbuffer serialized version of this object.
std::vector<flatbuffers::Offset<ray::protocol::ResourceIdSetInfo>> ToFlatbuf(
flatbuffers::FlatBufferBuilder &fbb) const;
/// \brief Serialize this object as a string.
///
/// \return A serialized string of this object.
/// TODO(zhijunfu): this can be removed after raylet client is migrated to grpc.
const std::string Serialize() const;
private:
/// A mapping from resource name to a set of resource IDs for that resource.
+1 -1
View File
@@ -3,9 +3,9 @@
#include <inttypes.h>
#include "ray/common/task/task_common.h"
#include "ray/common/task/task_execution_spec.h"
#include "ray/common/task/task_spec.h"
#include "ray/protobuf/common.pb.h"
namespace ray {
+1 -1
View File
@@ -6,7 +6,7 @@
#include "ray/common/buffer.h"
#include "ray/common/id.h"
#include "ray/common/task/task_spec.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
#include "ray/util/util.h"
namespace ray {
+1 -3
View File
@@ -8,12 +8,10 @@
#include "ray/core_worker/task_execution.h"
#include "ray/core_worker/task_interface.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
namespace ray {
using rpc::RayletClient;
/// The root class that contains all the core and language-independent functionalities
/// of the worker. This class is supposed to be used to implement app-language (Java,
/// Python, etc) workers.
@@ -4,9 +4,9 @@
#include "ray/core_worker/common.h"
#include "ray/core_worker/core_worker.h"
#include "ray/core_worker/lib/java/jni_utils.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
inline ray::RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) {
inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) {
return reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)->GetRayletClient();
}
-2
View File
@@ -11,8 +11,6 @@
namespace ray {
using rpc::RayletClient;
class CoreWorker;
class CoreWorkerStoreProvider;
@@ -7,7 +7,7 @@
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
#include "ray/core_worker/store_provider/store_provider.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
namespace ray {
@@ -8,12 +8,10 @@
#include "ray/core_worker/common.h"
#include "ray/core_worker/store_provider/local_plasma_provider.h"
#include "ray/core_worker/store_provider/store_provider.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
namespace ray {
using rpc::RayletClient;
class CoreWorker;
/// The class provides implementations for accessing plasma store, which includes both
-2
View File
@@ -17,8 +17,6 @@
namespace ray {
using rpc::RayletClient;
class CoreWorker;
/// Options of a non-actor-creation task.
+1 -1
View File
@@ -10,7 +10,7 @@
#include "ray/core_worker/store_provider/local_plasma_provider.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
#include "src/ray/protobuf/direct_actor.pb.h"
#include "src/ray/util/test_util.h"
@@ -5,13 +5,11 @@
#include "ray/core_worker/object_interface.h"
#include "ray/core_worker/transport/transport.h"
#include "ray/rpc/raylet/raylet_client.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/worker_server.h"
namespace ray {
using rpc::RayletClient;
/// In raylet task submitter and receiver, a task is submitted to raylet, and possibly
/// gets forwarded to another raylet on which node the task should be executed, and
/// then a worker on that node gets this task and starts executing it.
+3 -2
View File
@@ -769,8 +769,9 @@ void ObjectManager::SpreadFreeObjectsRequest(
const std::vector<std::shared_ptr<rpc::ObjectManagerClient>> &rpc_clients) {
// This code path should be called from node manager.
rpc::FreeObjectsRequest free_objects_request;
IdVectorToProtobuf<ObjectID, rpc::FreeObjectsRequest>(
object_ids, free_objects_request, &rpc::FreeObjectsRequest::add_object_ids);
for (const auto &e : object_ids) {
free_objects_request.add_object_ids(e.Binary());
}
for (auto &rpc_client : rpc_clients) {
rpc_client->FreeObjects(free_objects_request, [](const Status &status,
-12
View File
@@ -11,18 +11,6 @@ enum Language {
CPP = 2;
}
// Resource id set info.
message ResourceIdSetInfo {
// The name of the resource.
bytes resource_name = 1;
// The resource IDs reserved for this worker.
repeated uint64 resource_ids = 2;
// The fraction of each resource ID that is reserved for this worker. Note
// that the length of this list must be the same as the length of
// resource_ids.
repeated double resource_fractions = 3;
}
// Type of a worker.
enum WorkerType {
WORKER = 0;
-215
View File
@@ -1,215 +0,0 @@
syntax = "proto3";
package ray.rpc;
import "src/ray/protobuf/common.proto";
import "src/ray/protobuf/gcs.proto";
/// NOTE(Joey Jiang) Every request defined in this file should have a `worker_id` field,
/// which will be used in `NodeManager::PreprocessRequest`.
/// Service request and reply messages.
message RegisterClientRequest {
// The worker id.
bytes worker_id = 1;
// Indicates the client is a worker or a driver.
bool is_worker = 2;
// The process ID of this worker.
uint32 worker_pid = 3;
// The job ID.
bytes job_id = 4;
// Language of this worker.
Language language = 5;
// Port that this worker is listening on.
// If port > 0, then worker will listen to this port and wait for
// raylet to push tasks, instead of invoking GetTask().
int32 port = 6;
}
message RegisterClientReply {
repeated int32 gpu_ids = 1;
}
message SubmitTaskRequest {
bytes worker_id = 1;
TaskSpec task_spec = 2;
}
message SubmitTaskReply {
}
message DisconnectClientRequest {
bytes worker_id = 1;
}
message DisconnectClientReply {
}
message GetTaskRequest {
bytes worker_id = 1;
}
message GetTaskReply {
// A string of bytes representing the task specification.
bytes task_spec = 1;
// A list of the resources reserved for this worker.
repeated ResourceIdSetInfo fractional_resource_ids = 2;
}
message TaskDoneRequest {
bytes worker_id = 1;
}
message TaskDoneReply {
}
message FetchOrReconstructRequest {
// The worker ID.
bytes worker_id = 1;
// List of object IDs of the objects that we want to reconstruct or fetch.
repeated bytes object_ids = 2;
// Indicates that we only want to fetch objects, not reconstruct them.
bool fetch_only = 3;
// The current task ID. If fetch_only is false, then this task is blocked.
bytes task_id = 4;
}
message FetchOrReconstructReply {
}
message NotifyUnblockedRequest {
bytes worker_id = 1;
// The current task ID. This task is no longer blocked.
bytes task_id = 2;
}
message NotifyUnblockedReply {
}
message WaitRequest {
// The worker ID.
bytes worker_id = 1;
// List of object ids we'll be waiting on.
repeated bytes object_ids = 2;
// Number of objects expected to be returned, if available.
uint64 num_ready_objects = 3;
// Timeout in milliseconds.
int64 timeout = 4;
// Whether to wait until objects appear locally.
bool wait_local = 5;
// The current task ID. If there are less than num_ready_objects local, then
// this task is blocked.
bytes task_id = 6;
}
message WaitReply {
// List of object ids found.
repeated bytes found = 1;
// List of object ids not found.
repeated bytes remaining = 2;
}
message PushErrorRequest {
// The worker ID.
bytes worker_id = 1;
// The job id that the error is for.
bytes job_id = 2;
// The type of the error.
bytes type = 3;
// The error message.
bytes error_message = 4;
// The timestamp of the error message.
double timestamp = 5;
}
message PushErrorReply {
}
message PushProfileEventsRequest {
bytes worker_id = 1;
ProfileTableData profile_table_data = 2;
}
message PushProfileEventsReply {
}
message FreeObjectsInStoreRequest {
// The worker ID.
bytes worker_id = 1;
// Whether keep this request within the local object store
// or send it to all of the object stores.
bool local_only = 2;
// Whether also delete objects' creating tasks from GCS.
bool delete_creating_tasks = 3;
// List of object ids to delete from the object store.
repeated bytes object_ids = 4;
}
message FreeObjectsInStoreReply {
}
message PrepareActorCheckpointRequest {
bytes worker_id = 1;
bytes actor_id = 2;
}
message PrepareActorCheckpointReply {
bytes worker_id = 1;
bytes checkpoint_id = 2;
}
message NotifyActorResumedFromCheckpointRequest {
bytes worker_id = 1;
// ID of the actor that resumed.
bytes actor_id = 2;
// ID of the checkpoint from which the actor was resumed.
bytes checkpoint_id = 3;
}
message NotifyActorResumedFromCheckpointReply {
}
message SetResourceRequest {
bytes worker_id = 1;
// Name of the resource to be set.
bytes resource_name = 2;
// Capacity of the resource to be set.
double capacity = 3;
// Client ID where this resource will be set.
bytes client_id = 4;
}
message SetResourceReply {
}
message HeartbeatRequest {
bytes worker_id = 1;
bool is_worker = 2;
}
message HeartbeatReply {
}
/// Worker-to-raylet RPC service interface.
service RayletService {
// Register a new worker to the raylet.
rpc RegisterClient(RegisterClientRequest) returns (RegisterClientReply);
// Submit a task to the raylet.
rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply);
// Disconnect this client from raylet gracefully.
rpc DisconnectClient(DisconnectClientRequest) returns (DisconnectClientReply);
// Get a new task from the raylet.
rpc GetTask(GetTaskRequest) returns (GetTaskReply);
// Notify the raylet that a task is finished.
rpc TaskDone(TaskDoneRequest) returns (TaskDoneReply);
// Reconstruct or fetch possibly lost objects.
rpc FetchOrReconstruct(FetchOrReconstructRequest) returns (FetchOrReconstructReply);
// For a worker that was blocked on some object(s), tell the raylet
// that the worker is now unblocked.
rpc NotifyUnblocked(NotifyUnblockedRequest) returns (NotifyUnblockedReply);
// Wait for objects to be ready either from local or remote plasma stores.
// The `WaitReply` contains the objects found and objects remaining.
rpc Wait(WaitRequest) returns (WaitReply);
// Push an error to the relevant driver.
rpc PushError(PushErrorRequest) returns (PushErrorReply);
// Push some profiling events to the GCS. When sending this message to the
// node manager, the message itself is serialized as a ProfileTableData object.
rpc PushProfileEvents(PushProfileEventsRequest) returns (PushProfileEventsReply);
// Free the objects in plasma objects store.
rpc FreeObjectsInStore(FreeObjectsInStoreRequest) returns (FreeObjectsInStoreReply);
// Request raylet backend to prepare a checkpoint for an actor.
rpc PrepareActorCheckpoint(PrepareActorCheckpointRequest)
returns (PrepareActorCheckpointReply);
// Notify raylet backend that an actor was resumed from a checkpoint.
rpc NotifyActorResumedFromCheckpoint(NotifyActorResumedFromCheckpointRequest)
returns (NotifyActorResumedFromCheckpointReply);
// Set dynamic custom resource.
rpc SetResource(SetResourceRequest) returns (SetResourceReply);
// Send a heartbeat message to raylet.
rpc Heartbeat(HeartbeatRequest) returns (HeartbeatReply);
}
+3 -1
View File
@@ -8,7 +8,9 @@ message AssignTaskRequest {
// The task to be pushed.
Task task = 1;
// A list of the resources reserved for this worker.
repeated ResourceIdSetInfo resource_ids = 2;
// TODO(zhijunfu): `resource_ids` is represented as
// flatbuffers-serialized bytes, will be moved to protobuf later.
bytes resource_ids = 2;
}
message AssignTaskReply {
@@ -0,0 +1,292 @@
#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h"
#include <jni.h>
#include "ray/common/id.h"
#include "ray/core_worker/lib/java/jni_utils.h"
#include "ray/raylet/raylet_client.h"
#include "ray/util/logging.h"
#ifdef __cplusplus
extern "C" {
#endif
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeInit
 * Signature: (Ljava/lang/String;[BZ[B)J
 *
 * Creates a RayletClient for the given socket name, worker ID and job ID and
 * returns a handle to it. The handle is the address of a heap-allocated
 * std::unique_ptr<RayletClient>, which the other native methods cast back
 * from the `client` argument. Returns 0 if the socket name could not be
 * converted to a native string.
 */
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
    JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
    jbyteArray jobId) {
  const auto worker_id = JavaByteArrayToId<WorkerID>(env, workerId);
  const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
  // The second argument is an optional `jboolean *isCopy` out-parameter, not a
  // flag, so pass a null pointer explicitly.
  const char *nativeString = env->GetStringUTFChars(sockName, nullptr);
  if (nativeString == nullptr) {
    // The conversion failed; do not construct a client from a null string.
    return 0;
  }
  auto raylet_client = new std::unique_ptr<RayletClient>(
      new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
  env->ReleaseStringUTFChars(sockName, nativeString);
  return reinterpret_cast<jlong>(raylet_client);
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeSubmitTask
 * Signature: (J[B)V
 *
 * Parses `taskSpec` as a serialized ray::rpc::TaskSpec message and submits the
 * resulting task through the raylet client behind `client`. A non-OK status
 * from the submission is surfaced through THROW_EXCEPTION_AND_RETURN_IF_NOT_OK.
 */
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
    JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
  auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);

  jbyte *data = env->GetByteArrayElements(taskSpec, nullptr);
  if (data == nullptr) {
    // The array contents could not be obtained; there is nothing to parse.
    return;
  }
  const jsize size = env->GetArrayLength(taskSpec);
  ray::rpc::TaskSpec task_spec_message;
  // NOTE(review): the result of ParseFromArray is not checked, so a malformed
  // buffer is submitted as whatever partially parsed message it produced.
  task_spec_message.ParseFromArray(data, size);
  // JNI_ABORT: the buffer was only read, so nothing needs to be copied back.
  env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);

  ray::TaskSpecification task_spec(task_spec_message);
  auto status = raylet_client->SubmitTask(task_spec);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeGetTask
 * Signature: (J)[B
 *
 * Asks the raylet client behind `client` for a task and returns its
 * serialized specification as a new Java byte array. Returns nullptr when the
 * request fails or the result array cannot be allocated.
 */
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
    JNIEnv *env, jclass, jlong client) {
  auto *handle = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);

  std::unique_ptr<ray::TaskSpecification> task_spec;
  auto get_task_status = (*handle)->GetTask(&task_spec);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, get_task_status, nullptr);

  // Hand the serialized spec to Java as a freshly allocated byte[].
  auto serialized = task_spec->Serialize();
  jbyteArray java_bytes = env->NewByteArray(serialized.size());
  if (java_bytes != nullptr) {
    env->SetByteArrayRegion(java_bytes, 0, serialized.size(),
                            reinterpret_cast<const jbyte *>(serialized.data()));
  }
  // A null array means allocation failed and an out-of-memory error is pending.
  return java_bytes;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeDestroy
 * Signature: (J)V
 *
 * Disconnects the raylet client behind `client` and frees the native handle.
 * The handle is always freed, even when the disconnect reports an error; the
 * error is still surfaced through THROW_EXCEPTION_AND_RETURN_IF_NOT_OK.
 */
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
    JNIEnv *env, jclass, jlong client) {
  auto raylet_client = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
  if (raylet_client == nullptr) {
    // Nothing was ever created for this handle.
    return;
  }
  auto status = (*raylet_client)->Disconnect();
  // Free the handle before the status check: the macro returns early on
  // failure, which previously leaked the client.
  delete raylet_client;
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeWaitObject
 * Signature: (J[[BIIZ[B)[Z
 *
 * Waits on the given object IDs through the raylet client behind `client`
 * and returns a boolean array parallel to `objectIds`: an entry is set to
 * true when the ID is in the first list of the wait result and to false when
 * it is in the second. Returns nullptr when the wait fails or the result
 * array cannot be allocated.
 */
JNIEXPORT jbooleanArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
    JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns,
    jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) {
  std::vector<ObjectID> object_ids;
  auto len = env->GetArrayLength(objectIds);
  for (int i = 0; i < len; i++) {
    jbyteArray object_id_bytes =
        static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
    const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
    object_ids.push_back(object_id);
    // Drop the local reference eagerly so long arrays do not exhaust the
    // local reference table.
    env->DeleteLocalRef(object_id_bytes);
  }
  const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);

  auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);

  // Invoke wait.
  WaitResultPair result;
  auto status =
      raylet_client->Wait(object_ids, numReturns, timeoutMillis,
                          static_cast<bool>(isWaitLocal), current_task_id, &result);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);

  // Convert result to java object.
  jbooleanArray resultArray = env->NewBooleanArray(object_ids.size());
  if (resultArray == nullptr) {
    return nullptr; /* out of memory error thrown */
  }
  // Writes `value` at the position of the first entry in `object_ids` that
  // equals each ID in `ids`.
  const auto mark = [&](const decltype(result.first) &ids, jboolean value) {
    for (const auto &id : ids) {
      for (size_t j = 0; j < object_ids.size(); ++j) {
        if (id == object_ids[j]) {
          env->SetBooleanArrayRegion(resultArray, static_cast<jsize>(j), 1, &value);
          break;
        }
      }
    }
  };
  mark(result.first, JNI_TRUE);
  mark(result.second, JNI_FALSE);
  return resultArray;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeGenerateActorCreationTaskId
 * Signature: ([B[BI)[B
 *
 * Derives an actor ID from the job ID, parent task ID and parent task
 * counter, then returns the ID of that actor's creation task as a new Java
 * byte array (nullptr if the array cannot be allocated).
 */
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId(
    JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
    jint parent_task_counter) {
  const auto native_job_id = JavaByteArrayToId<JobID>(env, jobId);
  const auto native_parent_id = JavaByteArrayToId<TaskID>(env, parentTaskId);

  const ActorID new_actor_id =
      ray::ActorID::Of(native_job_id, native_parent_id, parent_task_counter);
  const TaskID creation_task_id = ray::TaskID::ForActorCreationTask(new_actor_id);

  jbyteArray java_bytes = env->NewByteArray(creation_task_id.Size());
  if (java_bytes != nullptr) {
    env->SetByteArrayRegion(java_bytes, 0, creation_task_id.Size(),
                            reinterpret_cast<const jbyte *>(creation_task_id.Data()));
  }
  return java_bytes;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeGenerateActorTaskId
 * Signature: ([B[BI[B)[B
 *
 * Computes the ID of an actor task from the job ID, parent task ID, parent
 * task counter and actor ID, and returns it as a new Java byte array
 * (nullptr if the array cannot be allocated).
 */
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId(
    JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
    jint parent_task_counter, jbyteArray actorId) {
  const auto native_job_id = JavaByteArrayToId<JobID>(env, jobId);
  const auto native_parent_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
  const auto native_actor_id = JavaByteArrayToId<ActorID>(env, actorId);

  const TaskID generated_id = ray::TaskID::ForActorTask(
      native_job_id, native_parent_id, parent_task_counter, native_actor_id);

  jbyteArray java_bytes = env->NewByteArray(generated_id.Size());
  if (java_bytes != nullptr) {
    env->SetByteArrayRegion(java_bytes, 0, generated_id.Size(),
                            reinterpret_cast<const jbyte *>(generated_id.Data()));
  }
  return java_bytes;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeGenerateNormalTaskId
 * Signature: ([B[BI)[B
 *
 * Computes the ID of a normal (non-actor) task from the job ID, parent task
 * ID and parent task counter, and returns it as a new Java byte array
 * (nullptr if the array cannot be allocated).
 */
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId(
    JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
    jint parent_task_counter) {
  const auto native_job_id = JavaByteArrayToId<JobID>(env, jobId);
  const auto native_parent_id = JavaByteArrayToId<TaskID>(env, parentTaskId);

  const TaskID generated_id =
      ray::TaskID::ForNormalTask(native_job_id, native_parent_id, parent_task_counter);

  jbyteArray java_bytes = env->NewByteArray(generated_id.Size());
  if (java_bytes != nullptr) {
    env->SetByteArrayRegion(java_bytes, 0, generated_id.Size(),
                            reinterpret_cast<const jbyte *>(generated_id.Data()));
  }
  return java_bytes;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeFreePlasmaObjects
 * Signature: (J[[BZZ)V
 *
 * Converts `objectIds` to native object IDs and forwards them, together with
 * the `localOnly` and `deleteCreatingTasks` flags, to FreeObjects on the
 * raylet client behind `client`.
 */
JNIEXPORT void JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
    JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly,
    jboolean deleteCreatingTasks) {
  std::vector<ObjectID> ids_to_free;
  const auto count = env->GetArrayLength(objectIds);
  for (int index = 0; index < count; ++index) {
    auto element = static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, index));
    ids_to_free.push_back(JavaByteArrayToId<ObjectID>(env, element));
    // Release each element reference as soon as it has been converted.
    env->DeleteLocalRef(element);
  }

  auto *handle = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
  auto free_status = (*handle)->FreeObjects(ids_to_free, localOnly, deleteCreatingTasks);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, free_status, (void)0);
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativePrepareCheckpoint
 * Signature: (J[B)[B
 *
 * Calls PrepareActorCheckpoint for the given actor on the raylet client
 * behind `client` and returns the resulting checkpoint ID as a new Java byte
 * array. Returns nullptr when the request fails or the result array cannot
 * be allocated.
 */
JNIEXPORT jbyteArray JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
                                                                     jlong client,
                                                                     jbyteArray actorId) {
  auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
  const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
  ActorCheckpointID checkpoint_id;
  auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
  jbyteArray result = env->NewByteArray(checkpoint_id.Size());
  if (result == nullptr) {
    // Same handling as the other byte-array-returning methods in this file.
    return nullptr; /* out of memory error thrown */
  }
  env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
                          reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
  return result;
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeNotifyActorResumedFromCheckpoint
 * Signature: (J[B[B)V
 *
 * Forwards the actor ID and checkpoint ID to
 * NotifyActorResumedFromCheckpoint on the raylet client behind `client`.
 */
JNIEXPORT void JNICALL
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
    JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
  auto *handle = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);

  // Convert both Java byte arrays to their native ID types, actor first.
  const auto resumed_actor = JavaByteArrayToId<ActorID>(env, actorId);
  const auto source_checkpoint = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);

  auto notify_status =
      (*handle)->NotifyActorResumedFromCheckpoint(resumed_actor, source_checkpoint);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, notify_status, (void)0);
}
/*
 * Class:     org_ray_runtime_raylet_RayletClientImpl
 * Method:    nativeSetResource
 * Signature: (JLjava/lang/String;D[B)V
 *
 * Calls SetResource on the raylet client behind `client` with the given
 * resource name, capacity and node ID. Does nothing if the resource name
 * could not be converted to a native string.
 */
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
    JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
    jbyteArray nodeId) {
  auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
  const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
  // The second argument is an optional `jboolean *isCopy` out-parameter, not a
  // flag, so pass a null pointer explicitly.
  const char *native_resource_name = env->GetStringUTFChars(resourceName, nullptr);
  if (native_resource_name == nullptr) {
    // The conversion failed; do not pass a null name to SetResource.
    return;
  }

  auto status = raylet_client->SetResource(native_resource_name,
                                           static_cast<double>(capacity), node_id);
  // Release the native string before the status check, which may return early.
  env->ReleaseStringUTFChars(resourceName, native_resource_name);
  THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
}
#ifdef __cplusplus
}
#endif
+1
View File
@@ -7,6 +7,7 @@
#include "ray/common/task/task_execution_spec.h"
#include "ray/common/task/task_spec.h"
#include "ray/common/task/task_util.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/lineage_cache.h"
#include "ray/util/test_util.h"
File diff suppressed because it is too large Load Diff
+37 -98
View File
@@ -7,8 +7,6 @@
#include "ray/rpc/client_call.h"
#include "ray/rpc/node_manager/node_manager_server.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/raylet/raylet_server.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/task/task.h"
#include "ray/common/client_connection.h"
#include "ray/common/task/task_common.h"
@@ -66,8 +64,7 @@ struct NodeManagerConfig {
std::string session_dir;
};
class NodeManager : public rpc::NodeManagerServiceHandler,
public rpc::RayletServiceHandler {
class NodeManager : public rpc::NodeManagerServiceHandler {
public:
/// Create a node manager.
///
@@ -78,6 +75,23 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
std::shared_ptr<gcs::RedisGcsClient> gcs_client,
std::shared_ptr<ObjectDirectoryInterface> object_directory_);
/// Process a new client connection.
///
/// \param client The client to process.
/// \return Void.
void ProcessNewClient(LocalClientConnection &client);
/// Process a message from a client. This method is responsible for
/// explicitly listening for more messages from the client if the client is
/// still alive.
///
/// \param client The client that sent the message.
/// \param message_type The message type (e.g., a flatbuffer enum).
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessClientMessage(const std::shared_ptr<LocalClientConnection> &client,
int64_t message_type, const uint8_t *message_data);
/// Subscribe to the relevant GCS tables and set up handlers.
///
/// \return Status indicating whether this was done successfully or not.
@@ -91,91 +105,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// Record metrics.
void RecordMetrics() const;
public:
/// Get the port of the node manager rpc server.
int GetServerPort() const { return node_manager_server_.GetPort(); }
/// Preprocess request from raylet client. We will check whether the worker is being
/// killed due to job finishing.
///
/// \param worker_id The worker id.
/// \param request_name The request name.
/// \return False if there is no need to process this request.
bool PreprocessRequest(const WorkerID &worker_id, const std::string &request_name);
/// Implementation of node manager grpc service.
/// Handle a `ForwardTask` request.
///
/// \param request The request.
/// \param reply The reply that will be sent to client.
/// \param send_reply_callback Invoke this callback to send reply asynchronously.
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
/// Implementation of raylet grpc service handlers, please see definitions
/// in `src/ray/protobuf/raylet.proto` for more details.
void HandleRegisterClientRequest(const rpc::RegisterClientRequest &request,
rpc::RegisterClientReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request,
rpc::SubmitTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleDisconnectClientRequest(const rpc::DisconnectClientRequest &request,
rpc::DisconnectClientReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleGetTaskRequest(const rpc::GetTaskRequest &request, rpc::GetTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleTaskDoneRequest(const rpc::TaskDoneRequest &request,
rpc::TaskDoneReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleFetchOrReconstructRequest(
const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleNotifyUnblockedRequest(const rpc::NotifyUnblockedRequest &request,
rpc::NotifyUnblockedReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleWaitRequest(const rpc::WaitRequest &request, rpc::WaitReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandlePushErrorRequest(const rpc::PushErrorRequest &request,
rpc::PushErrorReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandlePushProfileEventsRequest(
const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleFreeObjectsInStoreRequest(
const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandlePrepareActorCheckpointRequest(
const rpc::PrepareActorCheckpointRequest &request,
rpc::PrepareActorCheckpointReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleNotifyActorResumedFromCheckpointRequest(
const rpc::NotifyActorResumedFromCheckpointRequest &request,
rpc::NotifyActorResumedFromCheckpointReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleSetResourceRequest(const rpc::SetResourceRequest &request,
rpc::SetResourceReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
void HandleHeartbeatRequest(const rpc::HeartbeatRequest &request,
rpc::HeartbeatReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
private:
/// Methods for handling clients.
@@ -276,11 +208,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \return Void.
void SubmitTask(const Task &task, const Lineage &uncommitted_lineage,
bool forwarded = false);
/// Handle the case that a worker is available.
///
/// \param id Id of the worker.
/// \return Void.
void HandleWorkerAvailable(const WorkerID &worker_id);
/// Assign a task. The task is assumed to not be queued in local_queues_.
///
/// \param task The task in question.
@@ -392,7 +319,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \param ray_get Whether the task is blocked in a `ray.get` call, as
/// opposed to a `ray.wait` call.
/// \return Void.
void HandleTaskBlocked(const WorkerID &worker_id,
void HandleTaskBlocked(const std::shared_ptr<LocalClientConnection> &client,
const std::vector<ObjectID> &required_object_ids,
const TaskID &current_task_id, bool ray_get);
@@ -405,7 +332,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \param client The client that is executing the unblocked task.
/// \param current_task_id The task that is unblocked.
/// \return Void.
void HandleTaskUnblocked(const WorkerID &worker_id, const TaskID &current_task_id);
void HandleTaskUnblocked(const std::shared_ptr<LocalClientConnection> &client,
const TaskID &current_task_id);
/// Kill a worker.
///
@@ -481,8 +409,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// \param client The client that sent the message.
/// \param intentional_disconnect Whether the client was intentionally disconnected.
/// \return Void.
void ProcessDisconnectClientMessage(const WorkerID &worker_id,
bool intentional_disconnect = false);
void ProcessDisconnectClientMessage(
const std::shared_ptr<LocalClientConnection> &client,
bool intentional_disconnect = false);
/// Process client message of SubmitTask
///
@@ -550,6 +479,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
bool intentional_disconnect);
/// Finish assigning a task to a worker.
///
/// \param task_id Id of the task.
/// \param worker Worker which the task is assigned to.
/// \param success Whether the task is successfully assigned to the worker.
/// \return void.
void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success);
/// Handle a `ForwardTask` request.
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) override;
// GCS client ID for this node.
ClientID client_id_;
boost::asio::io_service &io_service_;
@@ -609,9 +551,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
/// The node manager RPC service.
rpc::NodeManagerGrpcService node_manager_service_;
/// The raylet RPC service.
rpc::RayletGrpcService raylet_service_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s
/// as well as all `WorkerTaskClient`s.
rpc::ClientCallManager client_call_manager_;
+57 -4
View File
@@ -7,6 +7,34 @@
#include "ray/common/status.h"
namespace {
/// Build a list of message-type names from a null-terminated array of C strings.
///
/// \param enum_names_ptr Null-terminated array of enum value names.
/// \param start_index Number of "EmptyMessageType" placeholders to put in front of
/// the real names.
/// \param end_index Expected index of the last entry in the resulting list.
/// \return The placeholders followed by every name in `enum_names_ptr`.
const std::vector<std::string> GenerateEnumNames(const char *const *enum_names_ptr,
                                                 int start_index, int end_index) {
  std::vector<std::string> names;

  // Leading placeholders, one per index below `start_index`.
  int padding = 0;
  while (padding < start_index) {
    names.emplace_back("EmptyMessageType");
    ++padding;
  }

  // Copy names until the terminating null entry.
  for (const char *const *cursor = enum_names_ptr; *cursor != nullptr; ++cursor) {
    names.emplace_back(*cursor);
  }

  // The last populated index has to line up with the declared maximum.
  RAY_CHECK(static_cast<size_t>(end_index) == names.size() - 1)
      << "Message Type mismatch!";
  return names;
}
/// Names of the node manager protocol message types, built once at static
/// initialization from the flatbuffers-generated name table. Entries below
/// `MessageType::MIN` are "EmptyMessageType" placeholders, and the builder
/// checks that the last index equals `MessageType::MAX`.
static const std::vector<std::string> node_manager_message_enum =
GenerateEnumNames(ray::protocol::EnumNamesMessageType(),
static_cast<int>(ray::protocol::MessageType::MIN),
static_cast<int>(ray::protocol::MessageType::MAX));
} // namespace
namespace ray {
namespace raylet {
@@ -23,10 +51,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_
node_manager_(main_service, node_manager_config, object_manager_, gcs_client_,
object_directory_),
socket_name_(socket_name),
raylet_server_("Raylet", socket_name),
raylet_service_(main_service, node_manager_) {
raylet_server_.RegisterService(raylet_service_);
raylet_server_.Run();
acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
socket_(main_service) {
// Start listening for clients.
DoAccept();
RAY_CHECK_OK(RegisterGcs(
node_ip_address, socket_name_, object_manager_config.store_socket_name,
@@ -80,6 +108,31 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
return Status::OK();
}
void Raylet::DoAccept() {
acceptor_.async_accept(socket_, boost::bind(&Raylet::HandleAccept, this,
boost::asio::placeholders::error));
}
/// Completion handler for `DoAccept`. On success the accepted socket is wrapped
/// in a `LocalClientConnection` whose callbacks forward to the node manager.
/// Another accept is started whether or not this one succeeded.
void Raylet::HandleAccept(const boost::system::error_code &error) {
  if (!error) {
    // TODO: typedef these handlers.
    // Called for the newly created connection.
    ClientHandler<boost::asio::local::stream_protocol> on_new_client =
        [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); };
    // Called for each message received on the connection.
    MessageHandler<boost::asio::local::stream_protocol> on_client_message =
        [this](std::shared_ptr<LocalClientConnection> client, int64_t message_type,
               const uint8_t *message) {
          node_manager_.ProcessClientMessage(client, message_type, message);
        };
    // Hand the accepted socket over to a new connection object.
    const auto disconnect_type =
        static_cast<int64_t>(protocol::MessageType::DisconnectClient);
    auto new_connection = LocalClientConnection::Create(
        on_new_client, on_client_message, std::move(socket_), "worker",
        node_manager_message_enum, disconnect_type);
  }
  // We're ready to accept another client.
  DoAccept();
}
} // namespace raylet
} // namespace ray
+4 -6
View File
@@ -75,12 +75,10 @@ class Raylet {
/// The name of the socket this raylet listens on.
std::string socket_name_;
/// The gRPC server, listens for local raylet client connections through a unix domain
/// socket.
rpc::GrpcServer raylet_server_;
/// The gRPC service.
rpc::RayletGrpcService raylet_service_;
/// An acceptor for new clients.
boost::asio::local::stream_protocol::acceptor acceptor_;
/// The socket to listen on for new clients.
boost::asio::local::stream_protocol::socket socket_;
};
} // namespace raylet
+392
View File
@@ -0,0 +1,392 @@
#include "raylet_client.h"
#include <inttypes.h>
#include <netdb.h>
#include <netinet/in.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>
#include "ray/common/common_protocol.h"
#include "ray/common/ray_config.h"
#include "ray/common/task/task_spec.h"
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/util/logging.h"
using MessageType = ray::protocol::MessageType;
// TODO(rkn): The io methods below should be removed.
/// Open a Unix domain stream socket and connect it to `socket_pathname`.
///
/// \param socket_pathname Filesystem path of the socket to connect to.
/// \return The connected file descriptor, or -1 if the socket could not be
/// created, the path does not fit in `sun_path`, or connect() failed.
int connect_ipc_sock(const std::string &socket_pathname) {
  const int fd = socket(AF_UNIX, SOCK_STREAM, 0);
  if (fd < 0) {
    RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname;
    return -1;
  }

  struct sockaddr_un address;
  memset(&address, 0, sizeof(address));
  address.sun_family = AF_UNIX;

  // The path plus its terminating NUL has to fit in sun_path.
  const size_t path_bytes = socket_pathname.length() + 1;
  if (path_bytes > sizeof(address.sun_path)) {
    RAY_LOG(ERROR) << "Socket pathname is too long.";
    close(fd);
    return -1;
  }
  strncpy(address.sun_path, socket_pathname.c_str(), path_bytes);

  const int rc = connect(fd, (struct sockaddr *)&address, sizeof(address));
  if (rc != 0) {
    close(fd);
    return -1;
  }
  return fd;
}
/// Read exactly `length` bytes from `socket_fd` into `cursor`.
///
/// A read() that fails with EAGAIN, EWOULDBLOCK or EINTR is retried.
///
/// \param socket_fd File descriptor to read from.
/// \param cursor Destination buffer with room for at least `length` bytes.
/// \param length Number of bytes to read; 0 returns immediately.
/// \return 0 once all bytes were read; -1 on any other read() error (errno is
/// left as set by read()) or if end-of-file is hit before `length` bytes arrive.
int read_bytes(int socket_fd, uint8_t *cursor, size_t length) {
  size_t offset = 0;
  while (offset < length) {
    const ssize_t nbytes = read(socket_fd, cursor + offset, length - offset);
    if (nbytes > 0) {
      offset += static_cast<size_t>(nbytes);
      continue;
    }
    if (nbytes == 0) {
      // Encountered early EOF.
      return -1;
    }
    // nbytes < 0: retry transient failures, give up on anything else.
    if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
      return -1;  // Errno will be set.
    }
  }
  return 0;
}
/// Write exactly `length` bytes from `cursor` to `socket_fd`.
///
/// A write() that fails with EAGAIN, EWOULDBLOCK or EINTR is retried.
///
/// \param socket_fd File descriptor to write to.
/// \param cursor Source buffer holding at least `length` bytes; not
/// dereferenced when `length` is 0.
/// \param length Number of bytes to write.
/// \return 0 once all bytes were written; -1 on any other write() error (errno
/// is left as set by write()) or if write() reports 0 bytes written.
int write_bytes(int socket_fd, uint8_t *cursor, size_t length) {
  size_t offset = 0;
  while (offset < length) {
    // Write the unsent remainder and advance by however much was accepted.
    const ssize_t nbytes = write(socket_fd, cursor + offset, length - offset);
    if (nbytes > 0) {
      offset += static_cast<size_t>(nbytes);
      continue;
    }
    if (nbytes == 0) {
      // Encountered early EOF.
      return -1;
    }
    // nbytes < 0: retry transient failures, give up on anything else.
    if (errno != EAGAIN && errno != EWOULDBLOCK && errno != EINTR) {
      return -1;  // Errno will be set.
    }
  }
  return 0;
}
/// Connect to the raylet's Unix domain socket, retrying on failure.
///
/// \param raylet_socket Path of the raylet socket; must not be empty.
/// \param num_retries Maximum number of connection attempts; a negative value
/// selects `RayConfig::num_connect_attempts()`.
/// \param timeout Milliseconds to sleep after a failed attempt; a negative value
/// selects `RayConfig::connect_timeout_milliseconds()`.
RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries,
                                   int64_t timeout) {
  // Negative arguments mean "use the configured default".
  if (num_retries < 0) {
    num_retries = RayConfig::instance().num_connect_attempts();
  }
  if (timeout < 0) {
    timeout = RayConfig::instance().connect_timeout_milliseconds();
  }
  RAY_CHECK(!raylet_socket.empty());

  conn_ = -1;
  int num_attempts = 0;
  while (num_attempts < num_retries) {
    conn_ = connect_ipc_sock(raylet_socket);
    if (conn_ >= 0) {
      break;
    }
    // The first failure is silent; later ones are logged.
    if (num_attempts > 0) {
      RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << raylet_socket
                     << " (num_attempts = " << num_attempts
                     << ", num_retries = " << num_retries << ")";
    }
    // Sleep for timeout milliseconds.
    usleep(timeout * 1000);
    ++num_attempts;
  }

  // Every attempt failed: this is fatal.
  if (conn_ == -1) {
    RAY_LOG(FATAL) << "Could not connect to socket " << raylet_socket;
  }
}
/// Send an `IntentionalDisconnectClient` message to the raylet.
///
/// \return Always `Status::OK()`; a failed write is logged, not propagated.
ray::Status RayletConnection::Disconnect() {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreateDisconnectClient(builder));

  auto write_status = WriteMessage(MessageType::IntentionalDisconnectClient, &builder);
  // Disconnection is best effort: log the failure rather than crash.
  if (!write_status.ok()) {
    RAY_LOG(ERROR) << write_status.ToString()
                   << " [RayletClient] Failed to disconnect from raylet.";
  }
  return ray::Status::OK();
}
/// Read one framed message (cookie, type, length, payload) from the raylet.
///
/// \param type The message type the caller expects to receive.
/// \param message Receives the payload buffer; set to null if the connection
/// closed while reading.
/// \return OK if a message of the expected type was read; IOError if the
/// connection closed or a `DisconnectClient` message arrived; TypeError if a
/// different message type arrived.
ray::Status RayletConnection::ReadMessage(MessageType type,
                                          std::unique_ptr<uint8_t[]> &message) {
  // Shared exit path for any failed read: drop the payload and report closure.
  const auto connection_closed = [&message]() {
    message = nullptr;
    return ray::Status::IOError("[RayletClient] Raylet connection closed.");
  };

  int64_t cookie;
  if (read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie))) {
    return connection_closed();
  }
  RAY_CHECK(cookie == RayConfig::instance().ray_cookie());

  int64_t type_field;
  if (read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field))) {
    return connection_closed();
  }

  int64_t length;
  if (read_bytes(conn_, (uint8_t *)&length, sizeof(length))) {
    return connection_closed();
  }

  message = std::unique_ptr<uint8_t[]>(new uint8_t[length]);
  if (read_bytes(conn_, message.get(), length)) {
    // Handle the case in which the socket is closed.
    return connection_closed();
  }

  // An explicit disconnect message is reported the same way as a closed socket.
  if (type_field == static_cast<int64_t>(MessageType::DisconnectClient)) {
    return ray::Status::IOError("[RayletClient] Raylet connection closed.");
  }
  if (type_field != static_cast<int64_t>(type)) {
    return ray::Status::TypeError(
        std::string("[RayletClient] Raylet connection corrupted. ") +
        "Expected message type: " + std::to_string(static_cast<int64_t>(type)) +
        "; got message type: " + std::to_string(type_field) +
        ". Check logs or dmesg for previous errors.");
  }
  return ray::Status::OK();
}
/// Write one framed message (cookie, type, length, payload) to the raylet.
/// The whole frame is written while holding `write_mutex_`.
///
/// \param type The message type to put in the frame header.
/// \param fbb Builder holding the finished payload, or null for an empty payload.
/// \return OK if every part was written, IOError otherwise.
ray::Status RayletConnection::WriteMessage(MessageType type,
                                           flatbuffers::FlatBufferBuilder *fbb) {
  std::lock_guard<std::mutex> write_lock(write_mutex_);

  int64_t cookie = RayConfig::instance().ray_cookie();
  int64_t type_field = static_cast<int64_t>(type);
  int64_t length = 0;
  uint8_t *payload = nullptr;
  if (fbb) {
    length = fbb->GetSize();
    payload = fbb->GetBufferPointer();
  }

  // Each step runs only if the previous one succeeded, in wire order.
  const bool written =
      write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)) == 0 &&
      write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)) == 0 &&
      write_bytes(conn_, (uint8_t *)&length, sizeof(length)) == 0 &&
      write_bytes(conn_, payload, length * sizeof(char)) == 0;
  if (!written) {
    return ray::Status::IOError("[RayletClient] Connection closed unexpectedly.");
  }
  return ray::Status::OK();
}
/// Send a request and read its reply while holding `mutex_`.
///
/// \param request_type Message type of the request.
/// \param reply_type Message type expected for the reply.
/// \param reply_message Receives the reply payload.
/// \param fbb Builder holding the request payload, or null for none.
/// \return The write status if the write failed, otherwise the read status.
ray::Status RayletConnection::AtomicRequestReply(
    MessageType request_type, MessageType reply_type,
    std::unique_ptr<uint8_t[]> &reply_message, flatbuffers::FlatBufferBuilder *fbb) {
  std::lock_guard<std::mutex> request_lock(mutex_);
  auto write_status = WriteMessage(request_type, fbb);
  if (write_status.ok()) {
    return ReadMessage(reply_type, reply_message);
  }
  return write_status;
}
/// Connect to the raylet and send a `RegisterClientRequest` for this process.
/// The connection uses the default retry count and timeout. A failed
/// registration write is treated as a fatal check failure.
RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
                           bool is_worker, const JobID &job_id, const Language &language,
                           int port)
    : worker_id_(worker_id), is_worker_(is_worker), job_id_(job_id), language_(language) {
  // Negative retry/timeout arguments select the configured defaults.
  conn_.reset(new RayletConnection(raylet_socket, -1, -1));

  // Build the registration message, which includes this process's PID.
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreateRegisterClientRequest(
      builder, is_worker, to_flatbuf(builder, worker_id), getpid(),
      to_flatbuf(builder, job_id), language, port));

  // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed.
  auto register_status =
      conn_->WriteMessage(MessageType::RegisterClientRequest, &builder);
  RAY_CHECK_OK_PREPEND(register_status,
                       "[RayletClient] Unable to register worker with raylet.");
}
/// Send a `SubmitTask` message carrying the serialized task specification.
///
/// \param task_spec The task to submit.
/// \return Status of the write to the raylet.
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
  flatbuffers::FlatBufferBuilder builder;
  const auto serialized_spec = builder.CreateString(task_spec.Serialize());
  builder.Finish(ray::protocol::CreateSubmitTaskRequest(builder, serialized_spec));
  return conn_->WriteMessage(MessageType::SubmitTask, &builder);
}
/// Request a task from the raylet and record the resources assigned with it.
///
/// \param task_spec Receives a newly allocated copy of the task specification.
/// \return OK on success, otherwise the status of the failed request/reply.
ray::Status RayletClient::GetTask(std::unique_ptr<ray::TaskSpecification> *task_spec) {
  // This blocks until the raylet gives this client a task.
  std::unique_ptr<uint8_t[]> reply_buffer;
  auto request_status = conn_->AtomicRequestReply(
      MessageType::GetTask, MessageType::ExecuteTask, reply_buffer);
  if (!request_status.ok()) {
    return request_status;
  }

  // Parse the flatbuffer object.
  const auto *task_reply =
      flatbuffers::GetRoot<ray::protocol::GetTaskReply>(reply_buffer.get());

  // Replace the previously recorded resource IDs with the ones for this task.
  resource_ids_.clear();
  const auto *resource_entries = task_reply->fractional_resource_ids();
  for (size_t entry_index = 0; entry_index < resource_entries->size(); ++entry_index) {
    const auto *entry = resource_entries->Get(entry_index);
    auto &acquired = resource_ids_[string_from_flatbuf(*entry->resource_name())];

    const auto *ids = entry->resource_ids();
    const auto *fractions = entry->resource_fractions();
    const size_t id_count = ids->size();
    RAY_CHECK(id_count == fractions->size());
    RAY_CHECK(id_count > 0);

    for (size_t k = 0; k < id_count; ++k) {
      const int64_t id = ids->Get(k);
      const double fraction = fractions->Get(k);
      // With more than one ID, every fraction must be a whole number.
      if (id_count > 1) {
        RAY_CHECK(static_cast<int64_t>(fraction) == fraction);
      }
      acquired.emplace_back(id, fraction);
    }
  }

  // Return the copy of the task spec and pass ownership to the caller.
  task_spec->reset(
      new ray::TaskSpecification(string_from_flatbuf(*task_reply->task_spec())));
  return ray::Status::OK();
}
/// Send a `TaskDone` message, which has no payload.
///
/// \return Status of the write to the raylet.
ray::Status RayletClient::TaskDone() {
  return conn_->WriteMessage(MessageType::TaskDone, /*fbb=*/nullptr);
}
/// Send a `FetchOrReconstruct` message for a set of objects.
///
/// \param object_ids The objects the request refers to.
/// \param fetch_only Forwarded unchanged in the message.
/// \param current_task_id Forwarded unchanged in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
                                             bool fetch_only,
                                             const TaskID &current_task_id) {
  flatbuffers::FlatBufferBuilder builder;
  const auto ids_offset = to_flatbuf(builder, object_ids);
  const auto task_id_offset = to_flatbuf(builder, current_task_id);
  builder.Finish(ray::protocol::CreateFetchOrReconstruct(builder, ids_offset, fetch_only,
                                                         task_id_offset));
  return conn_->WriteMessage(MessageType::FetchOrReconstruct, &builder);
}
/// Send a `NotifyUnblocked` message for the given task.
///
/// \param current_task_id Forwarded unchanged in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
  flatbuffers::FlatBufferBuilder builder;
  const auto task_id_offset = to_flatbuf(builder, current_task_id);
  builder.Finish(ray::protocol::CreateNotifyUnblocked(builder, task_id_offset));
  return conn_->WriteMessage(MessageType::NotifyUnblocked, &builder);
}
/// Send a `WaitRequest` and split the reply into found and remaining objects.
///
/// \param object_ids The objects to wait for.
/// \param num_returns Forwarded unchanged in the request.
/// \param timeout_milliseconds Forwarded unchanged in the request.
/// \param wait_local Forwarded unchanged in the request.
/// \param current_task_id Forwarded unchanged in the request.
/// \param result IDs from the reply's `found` list are appended to
/// `result->first`, and those from `remaining` to `result->second`.
/// \return OK on success, otherwise the status of the failed request/reply.
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
                               int64_t timeout_milliseconds, bool wait_local,
                               const TaskID &current_task_id, WaitResultPair *result) {
  // Write request.
  flatbuffers::FlatBufferBuilder fbb;
  auto message = ray::protocol::CreateWaitRequest(
      fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local,
      to_flatbuf(fbb, current_task_id));
  fbb.Finish(message);
  std::unique_ptr<uint8_t[]> reply;
  auto status = conn_->AtomicRequestReply(MessageType::WaitRequest,
                                          MessageType::WaitReply, reply, &fbb);
  if (!status.ok()) return status;

  // Parse the flatbuffer object.
  auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply.get());

  // Index with the flatbuffers size type; `uint` is a non-standard typedef
  // that is not available on every platform.
  auto found = reply_message->found();
  for (flatbuffers::uoffset_t i = 0; i < found->size(); i++) {
    result->first.push_back(ObjectID::FromBinary(found->Get(i)->str()));
  }
  auto remaining = reply_message->remaining();
  for (flatbuffers::uoffset_t i = 0; i < remaining->size(); i++) {
    result->second.push_back(ObjectID::FromBinary(remaining->Get(i)->str()));
  }
  return ray::Status::OK();
}
/// Send a `PushErrorRequest` message to the raylet.
///
/// \param job_id Job the error belongs to.
/// \param type Error type string.
/// \param error_message Error text.
/// \param timestamp Forwarded unchanged in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
                                    const std::string &error_message, double timestamp) {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreatePushErrorRequest(
      builder, to_flatbuf(builder, job_id), builder.CreateString(type),
      builder.CreateString(error_message), timestamp));
  return conn_->WriteMessage(MessageType::PushErrorRequest, &builder);
}
/// Send the serialized profile events to the raylet.
///
/// \param profile_events Events to push; sent as their serialized string form.
/// \return Always `Status::OK()`; a failed write is logged, not propagated.
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(builder.CreateString(profile_events.SerializeAsString()));

  auto write_status =
      conn_->WriteMessage(MessageType::PushProfileEventsRequest, &builder);
  // Profiling is best effort: log the failure rather than crash.
  if (!write_status.ok()) {
    RAY_LOG(ERROR) << write_status.ToString()
                   << " [RayletClient] Failed to push profile events.";
  }
  return ray::Status::OK();
}
/// Send a `FreeObjectsInObjectStoreRequest` message for a set of objects.
///
/// \param object_ids The objects to free.
/// \param local_only Forwarded unchanged in the message.
/// \param delete_creating_tasks Forwarded unchanged in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_ids,
                                      bool local_only, bool delete_creating_tasks) {
  flatbuffers::FlatBufferBuilder builder;
  const auto ids_offset = to_flatbuf(builder, object_ids);
  builder.Finish(ray::protocol::CreateFreeObjectsRequest(
      builder, local_only, delete_creating_tasks, ids_offset));
  return conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &builder);
}
/// Send a `PrepareActorCheckpointRequest` and read back the checkpoint ID.
///
/// \param actor_id Actor named in the request.
/// \param checkpoint_id Receives the checkpoint ID from the reply; left
/// untouched if the request fails.
/// \return OK on success, otherwise the status of the failed request/reply.
ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
                                                 ActorCheckpointID &checkpoint_id) {
  flatbuffers::FlatBufferBuilder builder;
  const auto actor_id_offset = to_flatbuf(builder, actor_id);
  builder.Finish(
      ray::protocol::CreatePrepareActorCheckpointRequest(builder, actor_id_offset));

  std::unique_ptr<uint8_t[]> reply_buffer;
  auto request_status = conn_->AtomicRequestReply(
      MessageType::PrepareActorCheckpointRequest,
      MessageType::PrepareActorCheckpointReply, reply_buffer, &builder);
  if (!request_status.ok()) {
    return request_status;
  }

  const auto *checkpoint_reply =
      flatbuffers::GetRoot<ray::protocol::PrepareActorCheckpointReply>(
          reply_buffer.get());
  checkpoint_id =
      ActorCheckpointID::FromBinary(checkpoint_reply->checkpoint_id()->str());
  return ray::Status::OK();
}
/// Send a `NotifyActorResumedFromCheckpoint` message to the raylet.
///
/// \param actor_id Actor named in the message.
/// \param checkpoint_id Checkpoint named in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
    const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreateNotifyActorResumedFromCheckpoint(
      builder, to_flatbuf(builder, actor_id), to_flatbuf(builder, checkpoint_id)));
  return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &builder);
}
/// Send a `SetResourceRequest` message to the raylet.
///
/// \param resource_name Name of the resource.
/// \param capacity Capacity value carried in the message.
/// \param client_Id Node ID carried in the message.
/// \return Status of the write to the raylet.
ray::Status RayletClient::SetResource(const std::string &resource_name,
                                      const double capacity,
                                      const ray::ClientID &client_Id) {
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(ray::protocol::CreateSetResourceRequest(
      builder, builder.CreateString(resource_name), capacity,
      to_flatbuf(builder, client_Id)));
  return conn_->WriteMessage(MessageType::SetResourceRequest, &builder);
}
@@ -1,26 +1,18 @@
#ifndef RAY_RPC_RAYLET_CLIENT_H
#define RAY_RPC_RAYLET_CLIENT_H
#ifndef RAYLET_CLIENT_H
#define RAYLET_CLIENT_H
#include <ray/common/ray_config.h>
#include <ray/protobuf/gcs.pb.h>
#include <unistd.h>
#include <future>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>
#include <grpcpp/grpcpp.h>
#include "ray/common/status.h"
#include "ray/common/task/task_spec.h"
#include "src/ray/common/status.h"
#include "src/ray/protobuf/raylet.grpc.pb.h"
#include "src/ray/protobuf/raylet.pb.h"
#include "src/ray/rpc/client_call.h"
using ray::ActorCheckpointID;
using ray::ActorID;
using ray::ClientID;
using ray::JobID;
using ray::ObjectID;
using ray::TaskID;
@@ -28,39 +20,65 @@ using ray::WorkerID;
using ray::Language;
using ray::rpc::ProfileTableData;
using WaitResultPair = std::pair<std::vector<ObjectID>, std::vector<ObjectID>>;
namespace ray {
namespace rpc {
using MessageType = ray::protocol::MessageType;
using ResourceMappingType =
std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>>;
using WaitResultPair = std::pair<std::vector<ObjectID>, std::vector<ObjectID>>;
/// A connection to the raylet over a Unix domain socket. Owns the socket file
/// descriptor and closes it on destruction.
class RayletConnection {
public:
/// Connect to the raylet.
///
/// \param raylet_socket The name of the socket to use to connect to the raylet.
/// \param num_retries The maximum number of connection attempts. A negative
/// value selects the default from RayConfig.
/// \param timeout The time in milliseconds to sleep after a failed attempt. A
/// negative value selects the default from RayConfig.
RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout);
/// Close the socket file descriptor.
~RayletConnection() { close(conn_); }
/// Notify the raylet that this client is disconnecting gracefully. This
/// is used by actors to exit gracefully so that the raylet doesn't
/// propagate an error message to the driver.
///
/// \return ray::Status.
ray::Status Disconnect();
/// Read one message from the raylet.
///
/// \param type The message type that is expected.
/// \param message Receives the message payload.
/// \return ray::Status. IOError if the connection was closed, TypeError if a
/// message of a different type was received.
ray::Status ReadMessage(MessageType type, std::unique_ptr<uint8_t[]> &message);
/// Write one message to the raylet. Writes are serialized by `write_mutex_`.
///
/// \param type The message type to send.
/// \param fbb The builder holding the finished payload, or nullptr to send a
/// message with an empty payload.
/// \return ray::Status. IOError if the message could not be fully written.
ray::Status WriteMessage(MessageType type,
flatbuffers::FlatBufferBuilder *fbb = nullptr);
/// Write a request and then read its reply while holding `mutex_`.
///
/// \param request_type The message type of the request.
/// \param reply_type The message type expected for the reply.
/// \param reply_message Receives the reply payload.
/// \param fbb The builder holding the request payload, or nullptr for none.
/// \return ray::Status.
ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type,
std::unique_ptr<uint8_t[]> &reply_message,
flatbuffers::FlatBufferBuilder *fbb = nullptr);
private:
/// File descriptor of the Unix domain socket that connects to raylet.
int conn_;
/// A mutex to protect stateful operations of the raylet client.
std::mutex mutex_;
/// A mutex to protect write operations of the raylet client.
std::mutex write_mutex_;
};
/// Client used for communicating with the raylet.
class RayletClient {
public:
/// Constructor for the raylet client.
/// TODO(jzh): At present, client call manager and reply handler service are generated
/// in raylet client. Instead, we should add parameters to the constructor and pass them
/// in. Change them as input parameters once we changed the worker into a server.
/// Connect to the raylet.
///
/// \param[in] raylet_socket Unix domain socket of the raylet server.
/// \param[in] worker_id The worker id.
/// \param[in] is_worker Indicates whether a worker or a driver.
/// \param[in] job_id The job id that this raylet client belongs to.
/// \param[in] language The language type, python or java.
/// \param[in] port The listening port of the worker server, -1 means that the worker
/// does not have a server.
/// \param raylet_socket The name of the socket to use to connect to the raylet.
/// \param worker_id A unique ID to represent the worker.
/// \param is_worker Whether this client is a worker. If it is a worker, an
/// additional message will be sent to register as one.
/// \param job_id The ID of the driver. This is non-nil if the client is a driver.
/// \return The connection information.
RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
int port = -1);
~RayletClient();
ray::Status Disconnect() { return conn_->Disconnect(); };
/// Send disconnect request to local raylet.
ray::Status Disconnect();
/// Submit a task to the local raylet.
/// Submit a task using the raylet code path.
///
/// \param The task specification.
/// \return ray::Status.
@@ -160,7 +178,7 @@ class RayletClient {
Language GetLanguage() const { return language_; }
WorkerID GetWorkerId() const { return worker_id_; }
WorkerID GetWorkerID() const { return worker_id_; }
JobID GetJobID() const { return job_id_; }
@@ -169,53 +187,16 @@ class RayletClient {
const ResourceMappingType &GetResourceIDs() const { return resource_ids_; }
private:
/// Try to register the client with the raylet, retrying several times if
/// registration fails. This is needed because the raylet client may start
/// before the raylet server.
///
/// \param times Number of times to retry.
void TryRegisterClient(int retry_times);
ray::Status RegisterClient();
/// Send heartbeat requests to the raylet server.
void Heartbeat();
/// Id of the worker to which this raylet client belongs.
const WorkerID worker_id_;
/// Indicates whether this worker is a driver worker.
/// Driver is treated as a special worker.
const bool is_worker_;
const JobID job_id_;
const Language language_;
const int port_;
/// A map from resource name to the resource IDs that are currently reserved
/// for this worker. Each pair consists of the resource ID and the fraction
/// of that resource allocated for this worker.
ResourceMappingType resource_ids_;
/// The gRPC-generated stub.
std::unique_ptr<RayletService::Stub> stub_;
/// Service for handling reply.
boost::asio::io_service main_service_;
/// Asio work for main service.
boost::asio::io_service::work work_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager client_call_manager_;
/// The thread used to handle reply.
std::thread rpc_thread_;
/// Heartbeat timer.
boost::asio::deadline_timer heartbeat_timer_;
/// Indicates whether the connection has been closed.
bool is_connected_;
/// The connection to the raylet server.
std::unique_ptr<RayletConnection> conn_;
};
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_RAYLET_CLIENT_H
#endif
+1 -1
View File
@@ -5,6 +5,7 @@
#include <boost/asio.hpp>
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/reconstruction_policy.h"
#include "ray/object_manager/object_directory.h"
@@ -417,7 +418,6 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) {
auto task_reconstruction_data = std::make_shared<TaskReconstructionData>();
task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary());
task_reconstruction_data->set_num_reconstructions(0);
RAY_CHECK_OK(
mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr,
/*failure_callback=*/
-1
View File
@@ -3,7 +3,6 @@
// clang-format off
#include "ray/common/id.h"
#include "ray/common/ray_config.h"
#include "ray/common/task/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/raylet/reconstruction_policy.h"
+28 -36
View File
@@ -2,6 +2,7 @@
#include <boost/bind.hpp>
#include "ray/raylet/format/node_manager_generated.h"
#include "ray/raylet/raylet.h"
namespace ray {
@@ -10,27 +11,25 @@ namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
rpc::ClientCallManager &client_call_manager, bool is_worker)
std::shared_ptr<LocalClientConnection> connection,
rpc::ClientCallManager &client_call_manager)
: worker_id_(worker_id),
pid_(pid),
port_(port),
language_(language),
port_(port),
connection_(connection),
dead_(false),
blocked_(false),
num_missed_heartbeats_(0),
is_being_killed_(false),
client_call_manager_(client_call_manager),
is_worker_(is_worker) {
client_call_manager_(client_call_manager) {
if (port_ > 0) {
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
}
}
void Worker::MarkAsBeingKilled() { is_being_killed_ = true; }
void Worker::MarkDead() { dead_ = true; }
bool Worker::IsBeingKilled() const { return is_being_killed_; }
bool Worker::IsWorker() const { return is_worker_; }
bool Worker::IsDead() const { return dead_; }
void Worker::MarkBlocked() { blocked_ = true; }
@@ -44,8 +43,6 @@ pid_t Worker::Pid() const { return pid_; }
Language Worker::GetLanguage() const { return language_; }
const WorkerID &Worker::GetWorkerId() const { return worker_id_; }
int Worker::Port() const { return port_; }
void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; }
@@ -79,6 +76,10 @@ void Worker::AssignActorId(const ActorID &actor_id) {
const ActorID &Worker::GetActorId() const { return actor_id_; }
/// Return the client connection held by this worker, i.e. the `connection_`
/// member that the constructor initializes from its `connection` argument.
/// The shared pointer is returned by value, so the caller gets its own
/// reference to the same connection object.
const std::shared_ptr<LocalClientConnection> Worker::Connection() const {
return connection_;
}
const ResourceIdSet &Worker::GetLifetimeResourceIds() const {
return lifetime_resource_ids_;
}
@@ -112,16 +113,10 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) {
task_resource_ids_.Release(cpu_resources);
}
/// Remember the reply object and reply callback of a pending `GetTask`
/// request so that they can be used later when a task is assigned.
///
/// \param reply Reply message of the `GetTask` request; only the pointer is stored.
/// \param send_reply_callback Callback of the `GetTask` request.
void Worker::SetGetTaskReplyAndCallback(
rpc::GetTaskReply *reply, const rpc::SendReplyCallback &&send_reply_callback) {
// Only one `GetTask` request may be pending at a time: both slots must be empty.
RAY_CHECK(reply_ == nullptr && send_reply_callback_ == nullptr);
reply_ = reply;
// NOTE(review): the parameter is a `const &&`, so `std::move` yields a const
// rvalue and this assignment most likely copies the callback instead of moving
// it. Dropping the `const` would need a matching change in the declaration.
send_reply_callback_ = std::move(send_reply_callback);
}
bool Worker::UsePush() const { return rpc_client_ != nullptr; }
void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) {
void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
const std::function<void(Status)> finish_assign_callback) {
const TaskSpecification &spec = task.GetTaskSpecification();
if (rpc_client_ != nullptr) {
// Use push mode.
@@ -131,16 +126,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set)
task.GetTaskSpecification().GetMessage());
request.mutable_task()->mutable_task_execution_spec()->CopyFrom(
task.GetTaskExecutionSpec().GetMessage());
for (const auto &e : resource_id_set.ToProtobuf()) {
auto resource = request.add_resource_ids();
*resource = e;
}
request.set_resource_ids(resource_id_set.Serialize());
auto status = rpc_client_->AssignTask(
request, [](Status status, const rpc::AssignTaskReply &reply) {
// Worker has finished this task. There's nothing to do here
// and assigning new task will be done when raylet receives
// `TaskDone` message.
});
finish_assign_callback(status);
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId()
<< " to worker " << worker_id_;
@@ -151,19 +145,17 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set)
} else {
// Use pull mode. This corresponds to existing python/java workers that haven't been
// migrated to core worker architecture.
RAY_CHECK(reply_ != nullptr && send_reply_callback_ != nullptr);
reply_->set_task_spec(task.GetTaskSpecification().Serialize());
for (const auto &e : resource_id_set.ToProtobuf()) {
auto resource = reply_->add_fractional_resource_ids();
*resource = e;
}
send_reply_callback_(Status::OK(), nullptr, nullptr);
reply_ = nullptr;
send_reply_callback_ = nullptr;
flatbuffers::FlatBufferBuilder fbb;
auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb);
auto message =
protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()),
fbb.CreateVector(resource_id_set_flatbuf));
fbb.Finish(message);
Connection()->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::ExecuteTask), fbb.GetSize(),
fbb.GetBufferPointer(), finish_assign_callback);
}
// The status will be cleared when the worker dies.
AssignTaskId(spec.TaskId());
AssignJobId(spec.JobId());
}
} // namespace raylet
+15 -33
View File
@@ -3,15 +3,12 @@
#include <memory>
#include "ray/common/client_connection.h"
#include "ray/common/id.h"
#include "ray/common/task/scheduling_resources.h"
#include "ray/common/task/task.h"
#include "ray/common/task/task_common.h"
#include "ray/protobuf/common.pb.h"
#include "ray/rpc/worker/worker_client.h"
#include "src/ray/protobuf/gcs.pb.h"
#include "src/ray/protobuf/raylet.pb.h"
#include "src/ray/rpc/server_call.h"
namespace ray {
@@ -24,12 +21,12 @@ class Worker {
public:
/// A constructor that initializes a worker object.
Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
rpc::ClientCallManager &client_call_manager, bool is_worker = true);
std::shared_ptr<LocalClientConnection> connection,
rpc::ClientCallManager &client_call_manager);
/// A destructor responsible for freeing all worker state.
~Worker() {}
void MarkAsBeingKilled();
bool IsBeingKilled() const;
bool IsWorker() const;
void MarkDead();
bool IsDead() const;
void MarkBlocked();
void MarkUnblocked();
bool IsBlocked() const;
@@ -38,7 +35,6 @@ class Worker {
/// Return the worker's PID.
pid_t Pid() const;
Language GetLanguage() const;
const WorkerID &GetWorkerId() const;
int Port() const;
void AssignTaskId(const TaskID &task_id);
const TaskID &GetAssignedTaskId() const;
@@ -49,6 +45,7 @@ class Worker {
const JobID &GetAssignedJobId() const;
void AssignActorId(const ActorID &actor_id);
const ActorID &GetActorId() const;
const std::shared_ptr<LocalClientConnection> Connection() const;
const ResourceIdSet &GetLifetimeResourceIds() const;
void SetLifetimeResourceIds(ResourceIdSet &resource_ids);
@@ -60,34 +57,30 @@ class Worker {
ResourceIdSet ReleaseTaskCpuResources();
void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources);
int TickHeartbeatTimer() { return ++num_missed_heartbeats_; }
void ClearHeartbeat() { num_missed_heartbeats_ = 0; }
/// When receiving a `GetTask` request from worker, this function should be called to
/// pass the reply and callback of the request to the worker. Later, when we actually
/// assign a task to the worker, the reply will be filled and the callback will be
/// called.
void SetGetTaskReplyAndCallback(rpc::GetTaskReply *reply,
const rpc::SendReplyCallback &&send_reply_callback);
bool UsePush() const;
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set);
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
const std::function<void(Status)> finish_assign_callback);
private:
/// The worker's ID.
WorkerID worker_id_;
/// The worker's PID.
pid_t pid_;
/// The worker port.
int port_;
/// The language type of this worker.
Language language_;
/// Port that this worker listens on.
/// If port <= 0, this indicates that the worker will not listen to a port.
int port_;
/// Connection state of a worker.
std::shared_ptr<LocalClientConnection> connection_;
/// The worker's currently assigned task.
TaskID assigned_task_id_;
/// Job ID for the worker's current assigned task.
JobID assigned_job_id_;
/// The worker's actor ID. If this is nil, then the worker is not an actor.
ActorID actor_id_;
/// Whether the worker is dead.
bool dead_;
/// Whether the worker is blocked. Workers become blocked in a `ray.get`, if
/// they require a data dependency while executing a task.
bool blocked_;
@@ -98,22 +91,11 @@ class Worker {
// of a task.
ResourceIdSet task_resource_ids_;
std::unordered_set<TaskID> blocked_task_ids_;
/// How many heartbeats have been missed for this worker.
int num_missed_heartbeats_;
/// Indicates we have sent kill signal to the worker if it's true. We cannot treat the
/// worker process as really dead until we lost the heartbeats from the worker.
bool is_being_killed_;
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
/// workers.
rpc::ClientCallManager &client_call_manager_;
/// Indicates whether this is a worker or a driver.
bool is_worker_;
/// The rpc client to send tasks to this worker.
std::unique_ptr<rpc::WorkerTaskClient> rpc_client_;
/// Reply of the `GetTask` request.
rpc::GetTaskReply *reply_ = nullptr;
/// Callback of the `GetTask` request.
rpc::SendReplyCallback send_reply_callback_ = nullptr;
};
} // namespace raylet
+45 -61
View File
@@ -12,6 +12,29 @@
#include "ray/util/logging.h"
#include "ray/util/util.h"
namespace {
// Looks up the worker that owns `connection`. Every entry of `worker_pool` is
// inspected in turn; the first worker whose Connection() compares equal to
// `connection` is returned, and nullptr is returned when none matches.
std::shared_ptr<ray::raylet::Worker> GetWorker(
    const std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
    const std::shared_ptr<ray::LocalClientConnection> &connection) {
  for (const auto &candidate : worker_pool) {
    if (candidate->Connection() == connection) {
      return candidate;
    }
  }
  return nullptr;
}
// Erases `worker` from `worker_pool`. The return value reports whether an
// entry was actually erased: true if the worker was present, false otherwise.
bool RemoveWorker(std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
                  const std::shared_ptr<ray::raylet::Worker> &worker) {
  const auto num_erased = worker_pool.erase(worker);
  return num_erased != 0;
}
} // namespace
namespace ray {
namespace raylet {
@@ -50,8 +73,8 @@ WorkerPool::~WorkerPool() {
for (const auto &entry : states_by_lang_) {
// Kill all registered workers. NOTE(swang): This assumes that the registered
// workers were started by the pool.
for (const auto &worker_pair : entry.second.registered_workers) {
pids_to_kill.insert(worker_pair.second->Pid());
for (const auto &worker : entry.second.registered_workers) {
pids_to_kill.insert(worker->Pid());
}
// Kill all the workers that have been started but not registered.
for (const auto &starting_worker : entry.second.starting_worker_processes) {
@@ -166,8 +189,7 @@ pid_t WorkerPool::StartProcess(const std::vector<std::string> &worker_command_ar
return 0;
}
Status WorkerPool::RegisterWorker(const WorkerID &worker_id,
const std::shared_ptr<Worker> &worker) {
Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker) {
const auto pid = worker->Pid();
const auto port = worker->Port();
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port;
@@ -183,35 +205,34 @@ Status WorkerPool::RegisterWorker(const WorkerID &worker_id,
state.starting_worker_processes.erase(it);
}
state.registered_workers.emplace(worker_id, std::move(worker));
state.registered_workers.emplace(std::move(worker));
return Status::OK();
}
Status WorkerPool::RegisterDriver(const WorkerID &driver_id,
const std::shared_ptr<Worker> &driver) {
Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver) {
RAY_CHECK(!driver->GetAssignedTaskId().IsNil());
auto &state = GetStateForLanguage(driver->GetLanguage());
state.registered_drivers.emplace(driver_id, std::move(driver));
state.registered_drivers.insert(std::move(driver));
return Status::OK();
}
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(const WorkerID &worker_id) const {
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
const std::shared_ptr<LocalClientConnection> &connection) const {
for (const auto &entry : states_by_lang_) {
auto &registered_workers = entry.second.registered_workers;
auto it = registered_workers.find(worker_id);
if (it != registered_workers.end()) {
return it->second;
auto worker = GetWorker(entry.second.registered_workers, connection);
if (worker != nullptr) {
return worker;
}
}
return nullptr;
}
std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(const WorkerID &worker_id) const {
std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(
const std::shared_ptr<LocalClientConnection> &connection) const {
for (const auto &entry : states_by_lang_) {
auto &registered_drivers = entry.second.registered_drivers;
auto it = registered_drivers.find(worker_id);
if (it != registered_drivers.end()) {
return it->second;
auto driver = GetWorker(entry.second.registered_drivers, connection);
if (driver != nullptr) {
return driver;
}
}
return nullptr;
@@ -295,20 +316,18 @@ std::shared_ptr<Worker> WorkerPool::PopWorker(const TaskSpecification &task_spec
bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
auto &state = GetStateForLanguage(worker->GetLanguage());
RAY_CHECK(state.registered_workers.erase(worker->GetWorkerId()));
RAY_CHECK(RemoveWorker(state.registered_workers, worker));
stats::CurrentWorker().Record(
0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
// Indicates that we disconnect an idle worker successfully.
return (state.idle.erase(worker) > 0);
return RemoveWorker(state.idle, worker);
}
void WorkerPool::DisconnectDriver(const std::shared_ptr<Worker> &driver) {
auto &state = GetStateForLanguage(driver->GetLanguage());
RAY_CHECK(state.registered_drivers.erase(driver->GetWorkerId()));
RAY_CHECK(RemoveWorker(state.registered_drivers, driver));
stats::CurrentDriver().Record(
0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
@@ -325,8 +344,7 @@ std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
std::vector<std::shared_ptr<Worker>> workers;
for (const auto &entry : states_by_lang_) {
for (const auto &worker_pair : entry.second.registered_workers) {
auto &worker = worker_pair.second;
for (const auto &worker : entry.second.registered_workers) {
if (worker->GetAssignedJobId() == job_id) {
workers.push_back(worker);
}
@@ -381,16 +399,14 @@ std::string WorkerPool::DebugString() const {
void WorkerPool::RecordMetrics() const {
for (const auto &entry : states_by_lang_) {
// Record worker.
for (auto worker_pair : entry.second.registered_workers) {
auto &worker = worker_pair.second;
for (auto worker : entry.second.registered_workers) {
stats::CurrentWorker().Record(
worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
}
// Record driver.
for (auto driver_pair : entry.second.registered_drivers) {
auto &driver = driver_pair.second;
for (auto driver : entry.second.registered_drivers) {
stats::CurrentDriver().Record(
driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
@@ -398,38 +414,6 @@ void WorkerPool::RecordMetrics() const {
}
}
void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats,
std::vector<std::shared_ptr<Worker>> *dead_workers) {
for (const auto &entry : states_by_lang_) {
// Worker heartbeat.
for (const auto &worker_pair : entry.second.registered_workers) {
auto &worker = worker_pair.second;
if (worker->TickHeartbeatTimer() >= max_missed_heartbeats) {
dead_workers->emplace_back(worker);
}
}
// Driver heartbeat.
for (const auto &driver_pair : entry.second.registered_drivers) {
auto &driver = driver_pair.second;
if (driver->TickHeartbeatTimer() >= max_missed_heartbeats) {
dead_workers->emplace_back(driver);
}
}
}
}
std::shared_ptr<Worker> WorkerPool::GetWorker(const WorkerID &worker_id) {
auto worker = GetRegisteredWorker(worker_id);
if (!worker) {
worker = GetRegisteredDriver(worker_id);
if (!worker) {
return nullptr;
}
}
return worker;
}
} // namespace raylet
} // namespace ray
+8 -22
View File
@@ -52,28 +52,29 @@ class WorkerPool {
///
/// \param The Worker to be registered.
/// \return If the registration is successful.
Status RegisterWorker(const WorkerID &worker_id, const std::shared_ptr<Worker> &worker);
Status RegisterWorker(const std::shared_ptr<Worker> &worker);
/// Register a new driver.
/// Driver is treated as a special worker, so use WorkerID as key here.
///
/// \param The driver to be registered.
/// \return If the registration is successful.
Status RegisterDriver(const WorkerID &driver_id, const std::shared_ptr<Worker> &worker);
Status RegisterDriver(const std::shared_ptr<Worker> &worker);
/// Get the client connection's registered worker.
///
/// \param The client connection owned by a registered worker.
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a worker yet.
std::shared_ptr<Worker> GetRegisteredWorker(const WorkerID &worker_id) const;
std::shared_ptr<Worker> GetRegisteredWorker(
const std::shared_ptr<LocalClientConnection> &connection) const;
/// Get the client connection's registered driver.
///
/// \param The client connection owned by a registered driver.
/// \return The Worker that owns the given client connection. Returns nullptr
/// if the client has not registered a driver.
std::shared_ptr<Worker> GetRegisteredDriver(const WorkerID &driver_id) const;
std::shared_ptr<Worker> GetRegisteredDriver(
const std::shared_ptr<LocalClientConnection> &connection) const;
/// Disconnect a registered worker.
///
@@ -130,21 +131,6 @@ class WorkerPool {
/// Record metrics.
void RecordMetrics() const;
/// Tick the heartbeat timer and get the workers that have timed out.
/// A worker which has missed `max_missed_heartbeats` times would be treated as a
/// dead process or the network to it has been down.
///
/// \param[in] max_missed_heartbeats The maximum number of heartbeats that can be
/// missed before a worker times out.
/// \param[out] dead_workers Workers whose processes have been dead.
void TickHeartbeatTimer(int max_missed_heartbeats,
std::vector<std::shared_ptr<Worker>> *dead_workers);
/// Return the pointer to the worker according to the worker id.
///
/// \param worker_id The worker id.
std::shared_ptr<Worker> GetWorker(const WorkerID &worker_id);
protected:
/// Asynchronously start a new worker process. Once the worker process has
/// registered with an external server, the process should create and
@@ -182,9 +168,9 @@ class WorkerPool {
std::unordered_map<ActorID, std::shared_ptr<Worker>> idle_actor;
/// All workers that have registered and are still connected, including both
/// idle and executing.
std::unordered_map<WorkerID, std::shared_ptr<Worker>> registered_workers;
std::unordered_set<std::shared_ptr<Worker>> registered_workers;
/// All drivers that have registered and are still connected.
std::unordered_map<WorkerID, std::shared_ptr<Worker>> registered_drivers;
std::unordered_set<std::shared_ptr<Worker>> registered_drivers;
/// A map from the pids of starting worker processes
/// to the number of their unregistered workers.
std::unordered_map<pid_t, int> starting_worker_processes;
+21 -9
View File
@@ -71,9 +71,19 @@ class WorkerPoolTest : public ::testing::Test {
std::shared_ptr<Worker> CreateWorker(pid_t pid,
const Language &language = Language::PYTHON) {
WorkerID worker_id = WorkerID::FromRandom();
return std::shared_ptr<Worker>(new Worker(
worker_id, pid, language, /* listening port */ -1, client_call_manager_));
std::function<void(LocalClientConnection &)> client_handler =
[this](LocalClientConnection &client) { HandleNewClient(client); };
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
int64_t message_type, const uint8_t *message) {
HandleMessage(client, message_type, message);
};
boost::asio::local::stream_protocol::socket socket(io_service_);
auto client =
LocalClientConnection::Create(client_handler, message_handler, std::move(socket),
"worker", {}, error_message_type_);
return std::shared_ptr<Worker>(new Worker(WorkerID::FromRandom(), pid, language, -1,
client, client_call_manager_));
}
void SetWorkerCommands(const WorkerCommandMap &worker_commands) {
@@ -86,6 +96,10 @@ class WorkerPoolTest : public ::testing::Test {
boost::asio::io_service io_service_;
int64_t error_message_type_;
rpc::ClientCallManager client_call_manager_;
private:
void HandleNewClient(LocalClientConnection &){};
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
};
static inline TaskSpecification ExampleTaskSpec(
@@ -117,23 +131,21 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
workers.push_back(CreateWorker(pid));
}
for (const auto &worker : workers) {
WorkerID worker_id = worker->GetWorkerId();
// Check that there's still a starting worker process
// before all workers have been registered
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1);
// Check that we cannot lookup the worker before it's registered.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr);
RAY_CHECK_OK(worker_pool_.RegisterWorker(worker_id, worker));
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
RAY_CHECK_OK(worker_pool_.RegisterWorker(worker));
// Check that we can lookup the worker after it's registered.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), worker);
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker);
}
// Check that there's no starting worker process
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0);
for (const auto &worker : workers) {
WorkerID worker_id = worker->GetWorkerId();
worker_pool_.DisconnectWorker(worker);
// Check that we cannot lookup the worker after it's disconnected.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr);
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
}
}
+3 -13
View File
@@ -6,20 +6,14 @@ namespace ray {
namespace rpc {
void GrpcServer::Run() {
std::string server_address;
// Set unix domain socket or tcp address.
if (!unix_socket_path_.empty()) {
server_address = "unix://" + unix_socket_path_;
} else {
server_address = "0.0.0.0:" + std::to_string(port_);
}
std::string server_address("0.0.0.0:" + std::to_string(port_));
grpc::ServerBuilder builder;
// TODO(hchen): Add options for authentication.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
// Register all the services to this server.
if (services_.empty()) {
RAY_LOG(WARNING) << "No service found when start grpc server " << name_;
RAY_LOG(WARNING) << "No service is found when start grpc server " << name_;
}
for (auto &entry : services_) {
builder.RegisterService(&entry.get());
@@ -29,11 +23,7 @@ void GrpcServer::Run() {
cq_ = builder.AddCompletionQueue();
// Build and start server.
server_ = builder.BuildAndStart();
if (unix_socket_path_.empty()) {
// For a TCP-based server, the actual port is decided after `AddListeningPort`.
server_address = "0.0.0.0:" + std::to_string(port_);
}
RAY_LOG(INFO) << name_ << " server started, listening on " << server_address;
RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << ".";
// Create calls for all the server call factories.
for (auto &entry : server_call_factories_and_concurrencies_) {
+1 -12
View File
@@ -32,16 +32,7 @@ class GrpcServer {
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
GrpcServer(std::string name, const uint32_t port)
: name_(std::move(name)), port_(port), is_closed_(false) {}
/// Construct a gRPC server that listens on unix domain socket.
///
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] unix_socket_path Unix domain socket full path.
GrpcServer(std::string name, const std::string &unix_socket_path)
: GrpcServer(std::move(name), 0) {
unix_socket_path_ = unix_socket_path;
}
: name_(std::move(name)), port_(port), is_closed_(true) {}
/// Destruct this gRPC server.
~GrpcServer() { Shutdown(); }
@@ -82,8 +73,6 @@ class GrpcServer {
int port_;
/// Indicates whether this server has been closed.
bool is_closed_;
/// Unix domain socket path.
std::string unix_socket_path_;
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
std::vector<std::reference_wrapper<grpc::Service>> services_;
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
@@ -5,7 +5,6 @@
#include <grpcpp/grpcpp.h>
#include "ray/common/grpc_util.h"
#include "ray/common/status.h"
#include "ray/rpc/client_call.h"
#include "ray/util/logging.h"
@@ -5,7 +5,6 @@
#include <grpcpp/grpcpp.h>
#include "ray/common/grpc_util.h"
#include "ray/common/status.h"
#include "ray/util/logging.h"
#include "src/ray/protobuf/object_manager.grpc.pb.h"
-440
View File
@@ -1,440 +0,0 @@
#include "src/ray/rpc/raylet/raylet_client.h"
namespace ray {
namespace rpc {
/// Early-exit guard for functions that return a `Status`: when `connected`
/// evaluates to false, the enclosing function returns
/// `Status::Invalid("Raylet connection is closed.")`; otherwise execution
/// continues. The `do { ... } while (0)` wrapper makes the expansion behave as
/// a single statement that is terminated by the caller's semicolon.
#define RETURN_IF_DISCONNECTED(connected) \
do { \
if (!(connected)) { \
return Status::Invalid("Raylet connection is closed."); \
} \
} while (0)
/// Construct a raylet client, connect it to the raylet over a unix domain
/// socket and register it.
///
/// \param raylet_socket Path of the raylet's unix domain socket; it is prefixed
/// with "unix://" to form the gRPC target.
/// \param worker_id Id of the worker this client belongs to.
/// \param is_worker Stored in `is_worker_` and included in the log messages.
/// \param job_id Stored in `job_id_`.
/// \param language Stored in `language_`.
/// \param port Stored in `port_`.
RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
int port)
: worker_id_(worker_id),
is_worker_(is_worker),
job_id_(job_id),
language_(language),
port_(port),
main_service_(),
work_(main_service_),
client_call_manager_(main_service_),
heartbeat_timer_(main_service_),
is_connected_(false) {
// Open an insecure (unauthenticated) gRPC channel to the raylet socket and
// build the service stub on top of it.
std::shared_ptr<grpc::Channel> channel =
grpc::CreateChannel("unix://" + raylet_socket, grpc::InsecureChannelCredentials());
stub_ = RayletService::NewStub(channel);
// Run `main_service_` on a dedicated thread. The lambda captures `this`; the
// destructor stops the service and joins this thread.
rpc_thread_ = std::thread([this]() { main_service_.run(); });
RAY_LOG(DEBUG) << "Connecting to unix socket: "
<< "unix://" + raylet_socket
<< ", is worker: " << (is_worker_ ? "true" : "false")
<< ", worker id: " << worker_id;
// Try to register client `num_raylet_client_retry_times` times.
TryRegisterClient(RayConfig::instance().num_raylet_client_retry_times());
}
/// Register this client with the raylet, retrying on failure.
///
/// Each failed attempt is followed by a 500 ms sleep. On the first success the
/// client is marked connected and heartbeating is started. If all
/// `retry_times` attempts fail, a FATAL message is logged.
///
/// \param retry_times Maximum number of registration attempts.
void RayletClient::TryRegisterClient(int retry_times) {
  int attempts_made = 0;
  while (attempts_made < retry_times) {
    const auto register_status = RegisterClient();
    if (register_status.ok()) {
      // Registered: flag the connection as live and start sending heartbeats.
      is_connected_ = true;
      Heartbeat();
      return;
    }
    // Wait a little before the next attempt.
    std::this_thread::sleep_for(std::chrono::milliseconds(500));
    ++attempts_made;
  }

  // Every attempt failed.
  RAY_LOG(FATAL) << "Worker " << worker_id_
                 << " failed to register to raylet server, worker id: " << worker_id_
                 << ", pid: " << static_cast<int>(getpid())
                 << ", is worker: " << is_worker_;
}
/// Tear down the client: mark it disconnected, stop the io service and wait
/// for the reply-handling thread to finish.
RayletClient::~RayletClient() {
// Make subsequent requests bail out via RETURN_IF_DISCONNECTED.
is_connected_ = false;
// Stop the service first so that `main_service_.run()` on `rpc_thread_`
// returns; only then can the join below complete.
main_service_.stop();
rpc_thread_.join();
}
/// Send a disconnect request for this worker to the raylet and wait for the
/// reply (synchronous stub call).
///
/// \return The gRPC status of the call converted to a ray::Status. A failure
/// is also logged at ERROR level.
ray::Status RayletClient::Disconnect() {
  DisconnectClientRequest request;
  request.set_worker_id(worker_id_.Binary());

  DisconnectClientReply response;
  grpc::ClientContext client_context;
  const auto grpc_status = stub_->DisconnectClient(&client_context, request, &response);
  const bool succeeded = grpc_status.ok();
  if (!succeeded) {
    RAY_LOG(ERROR) << "Worker " << worker_id_
                   << " failed to disconnect from raylet, msg: "
                   << grpc_status.error_message();
  }
  return GrpcStatusToRayStatus(grpc_status);
}
/// Submit a task to the raylet through the client call manager.
///
/// \param task_spec Specification of the task to submit.
/// \return `Status::Invalid` if the client is already disconnected, otherwise
/// the status reported by the created call.
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
  RETURN_IF_DISCONNECTED(is_connected_);

  SubmitTaskRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());

  // Reply handler: the first failed reply while still connected flips the
  // client to disconnected and is logged; everything else is ignored.
  auto on_reply = [this](const Status &reply_status, const SubmitTaskReply &) {
    if (reply_status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send SubmitTaskRequest, msg: " << reply_status.message();
  };

  auto pending_call =
      client_call_manager_.CreateCall<RayletService, SubmitTaskRequest, SubmitTaskReply>(
          *stub_, &RayletService::Stub::PrepareAsyncSubmitTask, request, on_reply);
  return pending_call->GetStatus();
}
/// Ask the raylet for the next task to execute (synchronous stub call).
///
/// On success `resource_ids_` is rebuilt from the reply and `*task_spec` is
/// set to a new TaskSpecification built from the reply's task spec. On failure
/// `*task_spec` is reset to null, the error is logged and `resource_ids_` is
/// left untouched.
///
/// \param[out] task_spec Receives the assigned task specification.
/// \return `Status::Invalid` if the client is already disconnected, otherwise
/// the gRPC status of the call converted to a ray::Status.
ray::Status RayletClient::GetTask(std::unique_ptr<ray::TaskSpecification> *task_spec) {
  RETURN_IF_DISCONNECTED(is_connected_);

  GetTaskRequest request;
  request.set_worker_id(worker_id_.Binary());

  grpc::ClientContext context;
  GetTaskReply reply;
  // The actual RPC.
  const auto grpc_status = stub_->GetTask(&context, request, &reply);

  if (!grpc_status.ok()) {
    *task_spec = nullptr;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to get task, msg: " << grpc_status.error_message();
    return GrpcStatusToRayStatus(grpc_status);
  }

  // Rebuild the resources acquired for the assigned task.
  resource_ids_.clear();
  for (const auto &resource_entry : reply.fractional_resource_ids()) {
    auto &acquired = resource_ids_[resource_entry.resource_name()];

    // Every resource lists resource IDs (e.g., GPU 0) together with the amount
    // taken from each ID. A fractional amount is only allowed when there is a
    // single resource ID.
    const size_t id_count = resource_entry.resource_ids().size();
    const size_t fraction_count = resource_entry.resource_fractions().size();
    RAY_CHECK(id_count == fraction_count);
    RAY_CHECK(id_count > 0);

    const bool must_be_whole = id_count > 1;
    for (size_t idx = 0; idx < id_count; ++idx) {
      const int64_t id = resource_entry.resource_ids()[idx];
      const double fraction = resource_entry.resource_fractions()[idx];
      if (must_be_whole) {
        const int64_t truncated = fraction;
        RAY_CHECK(truncated == fraction);
      }
      acquired.emplace_back(id, fraction);
    }
  }

  task_spec->reset(new ray::TaskSpecification(reply.task_spec()));
  return GrpcStatusToRayStatus(grpc_status);
}
/// Tell the raylet that this worker has finished its current task.
///
/// \return `Status::Invalid` if the client is already disconnected, otherwise
/// the status reported by the created call.
ray::Status RayletClient::TaskDone() {
  RETURN_IF_DISCONNECTED(is_connected_);

  TaskDoneRequest request;
  request.set_worker_id(worker_id_.Binary());

  // Reply handler: a failed reply while still connected marks the client as
  // disconnected and is logged.
  auto on_reply = [this](const Status &reply_status, const TaskDoneReply &) {
    if (reply_status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send TaskDoneRequest, msg: " << reply_status.message();
  };

  auto pending_call =
      client_call_manager_.CreateCall<RayletService, TaskDoneRequest, TaskDoneReply>(
          *stub_, &RayletService::Stub::PrepareAsyncTaskDone, request, on_reply);
  return pending_call->GetStatus();
}
/// Send a fetch-or-reconstruct request for a set of objects to the raylet.
///
/// \param object_ids The objects the request is about.
/// \param fetch_only Forwarded in the request's `fetch_only` field.
/// \param current_task_id Forwarded in the request's `task_id` field.
/// \return `Status::Invalid` if the client is already disconnected, otherwise
/// the status reported by the created call.
ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
                                             bool fetch_only,
                                             const TaskID &current_task_id) {
  RETURN_IF_DISCONNECTED(is_connected_);

  FetchOrReconstructRequest request;
  request.set_fetch_only(fetch_only);
  request.set_task_id(current_task_id.Binary());
  request.set_worker_id(worker_id_.Binary());
  IdVectorToProtobuf<ObjectID, FetchOrReconstructRequest>(
      object_ids, request, &FetchOrReconstructRequest::add_object_ids);

  // Reply handler: a failed reply while still connected marks the client as
  // disconnected and is logged.
  auto on_reply = [this](const Status &reply_status, const FetchOrReconstructReply &) {
    if (reply_status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send FetchOrReconstructRequest, msg: "
                  << reply_status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, FetchOrReconstructRequest, FetchOrReconstructReply>(
              *stub_, &RayletService::Stub::PrepareAsyncFetchOrReconstruct, request,
              on_reply);
  return pending_call->GetStatus();
}
/// Send a `NotifyUnblockedRequest` for the given task to the raylet.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] current_task_id Forwarded as the request's `task_id` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::NotifyUnblocked(const TaskID &current_task_id) {
  RETURN_IF_DISCONNECTED(is_connected_);

  NotifyUnblockedRequest request;
  request.set_task_id(current_task_id.Binary());
  request.set_worker_id(worker_id_.Binary());

  auto on_reply = [this](const Status &status, const NotifyUnblockedReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send NotifyUnblockedRequest, msg: "
                  << status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, NotifyUnblockedRequest, NotifyUnblockedReply>(
              *stub_, &RayletService::Stub::PrepareAsyncNotifyUnblocked, request,
              on_reply);
  return pending_call->GetStatus();
}
/// Issue a `Wait` RPC through the synchronous stub call and unpack its reply.
///
/// \param[in] object_ids IDs added to the request's `object_ids` field.
/// \param[in] num_returns Forwarded as the request's `num_ready_objects` field.
/// \param[in] timeout_milliseconds Forwarded as the request's `timeout` field.
/// \param[in] wait_local Forwarded as the request's `wait_local` field.
/// \param[in] current_task_id Forwarded as the request's `task_id` field.
/// \param[out] result On success, `first` receives the reply's `found` IDs and
/// `second` receives the reply's `remaining` IDs. Left untouched on failure.
/// \return The gRPC status converted to a Ray status.
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
                               int64_t timeout_milliseconds, bool wait_local,
                               const TaskID &current_task_id, WaitResultPair *result) {
  RETURN_IF_DISCONNECTED(is_connected_);

  WaitRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_task_id(current_task_id.Binary());
  request.set_num_ready_objects(num_returns);
  request.set_timeout(timeout_milliseconds);
  request.set_wait_local(wait_local);
  IdVectorToProtobuf<ObjectID, WaitRequest>(object_ids, request,
                                            &WaitRequest::add_object_ids);

  grpc::ClientContext context;
  WaitReply reply;
  auto status = stub_->Wait(&context, request, &reply);
  if (!status.ok()) {
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send WaitRequest, msg: " << status.error_message();
    return GrpcStatusToRayStatus(status);
  }

  result->first = IdVectorFromProtobuf<ObjectID>(reply.found());
  result->second = IdVectorFromProtobuf<ObjectID>(reply.remaining());
  return GrpcStatusToRayStatus(status);
}
/// Send a `PushErrorRequest` describing an error to the raylet.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] job_id Forwarded as the request's `job_id` field.
/// \param[in] type Forwarded as the request's `type` field.
/// \param[in] error_message Forwarded as the request's `error_message` field.
/// \param[in] timestamp Forwarded as the request's `timestamp` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
                                    const std::string &error_message, double timestamp) {
  RETURN_IF_DISCONNECTED(is_connected_);

  PushErrorRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_job_id(job_id.Binary());
  request.set_type(type);
  request.set_error_message(error_message);
  request.set_timestamp(timestamp);

  auto on_reply = [this](const Status &status, const PushErrorReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send PushErrorRequest, msg: " << status.message();
  };

  auto pending_call =
      client_call_manager_.CreateCall<RayletService, PushErrorRequest, PushErrorReply>(
          *stub_, &RayletService::Stub::PrepareAsyncPushError, request, on_reply);
  return pending_call->GetStatus();
}
/// Send a `PushProfileEventsRequest` containing a copy of `profile_events`.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] profile_events Copied into the request's `profile_table_data` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
  RETURN_IF_DISCONNECTED(is_connected_);

  PushProfileEventsRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.mutable_profile_table_data()->CopyFrom(profile_events);

  auto on_reply = [this](const Status &status, const PushProfileEventsReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send PushProfileEventsRequest, msg: "
                  << status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, PushProfileEventsRequest, PushProfileEventsReply>(
              *stub_, &RayletService::Stub::PrepareAsyncPushProfileEvents, request,
              on_reply);
  return pending_call->GetStatus();
}
/// Send a `FreeObjectsInStoreRequest` for the given objects to the raylet.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] object_ids IDs added to the request's `object_ids` field.
/// \param[in] local_only Forwarded as the request's `local_only` field.
/// \param[in] delete_creating_tasks Forwarded as the request's
/// `delete_creating_tasks` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_ids,
                                      bool local_only, bool delete_creating_tasks) {
  RETURN_IF_DISCONNECTED(is_connected_);

  FreeObjectsInStoreRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_local_only(local_only);
  request.set_delete_creating_tasks(delete_creating_tasks);
  IdVectorToProtobuf<ray::ObjectID, FreeObjectsInStoreRequest>(
      object_ids, request, &FreeObjectsInStoreRequest::add_object_ids);

  auto on_reply = [this](const Status &status, const FreeObjectsInStoreReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send FreeObjectsInStoreRequest, msg: "
                  << status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, FreeObjectsInStoreRequest, FreeObjectsInStoreReply>(
              *stub_, &RayletService::Stub::PrepareAsyncFreeObjectsInStore, request,
              on_reply);
  return pending_call->GetStatus();
}
/// Issue a `PrepareActorCheckpoint` RPC through the synchronous stub call.
///
/// \param[in] actor_id Forwarded as the request's `actor_id` field.
/// \param[out] checkpoint_id On success, set from the reply's `checkpoint_id`
/// field. Left untouched on failure.
/// \return The gRPC status converted to a Ray status.
ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
                                                 ActorCheckpointID &checkpoint_id) {
  RETURN_IF_DISCONNECTED(is_connected_);

  PrepareActorCheckpointRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_actor_id(actor_id.Binary());

  grpc::ClientContext context;
  PrepareActorCheckpointReply reply;
  auto status = stub_->PrepareActorCheckpoint(&context, request, &reply);
  if (!status.ok()) {
    RAY_LOG(INFO) << "Worker " << worker_id_
                  << " failed to send PrepareActorCheckpointRequest, msg: "
                  << status.error_message();
    return GrpcStatusToRayStatus(status);
  }

  checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id());
  return GrpcStatusToRayStatus(status);
}
/// Send a `NotifyActorResumedFromCheckpointRequest` to the raylet.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] actor_id Forwarded as the request's `actor_id` field.
/// \param[in] checkpoint_id Forwarded as the request's `checkpoint_id` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
    const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) {
  RETURN_IF_DISCONNECTED(is_connected_);

  NotifyActorResumedFromCheckpointRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_actor_id(actor_id.Binary());
  request.set_checkpoint_id(checkpoint_id.Binary());

  auto on_reply = [this](const Status &status,
                         const NotifyActorResumedFromCheckpointReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "NotifyActorResumedFromCheckpoint failed, msg: "
                  << status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, NotifyActorResumedFromCheckpointRequest,
                      NotifyActorResumedFromCheckpointReply>(
              *stub_, &RayletService::Stub::PrepareAsyncNotifyActorResumedFromCheckpoint,
              request, on_reply);
  return pending_call->GetStatus();
}
/// Send a `SetResourceRequest` to the raylet.
///
/// The RPC is issued through the client call manager with a reply callback. If the
/// reply reports a failure while the client is still marked connected, the client is
/// flagged as disconnected and the failure is logged.
///
/// \param[in] resource_name Forwarded as the request's `resource_name` field.
/// \param[in] capacity Forwarded as the request's `capacity` field.
/// \param[in] client_id Forwarded as the request's `client_id` field.
/// \return The status reported by the created call object.
ray::Status RayletClient::SetResource(const std::string &resource_name,
                                      const double capacity,
                                      const ray::ClientID &client_id) {
  RETURN_IF_DISCONNECTED(is_connected_);

  SetResourceRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_client_id(client_id.Binary());
  request.set_resource_name(resource_name);
  request.set_capacity(capacity);

  auto on_reply = [this](const Status &status, const SetResourceReply &reply) {
    // Nothing to do on success, or if the disconnect was already recorded.
    if (status.ok() || !is_connected_) {
      return;
    }
    is_connected_ = false;
    RAY_LOG(INFO) << "SetResource failed, msg: " << status.message();
  };

  auto pending_call =
      client_call_manager_
          .CreateCall<RayletService, SetResourceRequest, SetResourceReply>(
              *stub_, &RayletService::Stub::PrepareAsyncSetResource, request, on_reply);
  return pending_call->GetStatus();
}
/// Issue a `RegisterClient` RPC through the synchronous stub call.
///
/// The request carries this client's worker ID, process ID, job ID, language, port
/// and worker flag. Unlike the other RPC wrappers in this file, this one does not
/// check the connection flag first. A failure is logged at DEBUG level.
///
/// \return The gRPC status converted to a Ray status.
ray::Status RayletClient::RegisterClient() {
  RegisterClientRequest request;
  request.set_worker_id(worker_id_.Binary());
  request.set_worker_pid(getpid());
  request.set_job_id(job_id_.Binary());
  request.set_is_worker(is_worker_);
  request.set_language(language_);
  request.set_port(port_);

  grpc::ClientContext context;
  RegisterClientReply reply;
  auto status = stub_->RegisterClient(&context, request, &reply);
  if (status.ok()) {
    return GrpcStatusToRayStatus(status);
  }

  RAY_LOG(DEBUG) << "Worker " << worker_id_
                 << " failed to register client, msg: " << status.error_message();
  return GrpcStatusToRayStatus(status);
}
void RayletClient::Heartbeat() {
if (!is_connected_) {
return;
}
HeartbeatRequest heartbeat_request;
heartbeat_request.set_is_worker(is_worker_);
heartbeat_request.set_worker_id(worker_id_.Binary());
auto callback = [this](const Status &status, const HeartbeatReply &reply) {
if (!status.ok() && is_connected_) {
is_connected_ = false;
RAY_LOG(INFO) << "Heartbeat failed, msg: " << status.message();
}
};
auto call =
client_call_manager_.CreateCall<RayletService, HeartbeatRequest, HeartbeatReply>(
*stub_, &RayletService::Stub::PrepareAsyncHeartbeat, heartbeat_request,
callback);
heartbeat_timer_.expires_from_now(boost::posix_time::milliseconds(
RayConfig::instance().heartbeat_timeout_milliseconds()));
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
RAY_CHECK(!error);
Heartbeat();
});
}
} // namespace rpc
} // namespace ray
-258
View File
@@ -1,258 +0,0 @@
#ifndef RAY_RPC_RAYLET_SERVER_H
#define RAY_RPC_RAYLET_SERVER_H
#include "src/ray/rpc/grpc_server.h"
#include "src/ray/rpc/server_call.h"
#include "src/ray/protobuf/raylet.grpc.pb.h"
#include "src/ray/protobuf/raylet.pb.h"
namespace ray {
namespace rpc {
/// Interface for implementations of the `RayletService`; see the service definition
/// in `src/ray/protobuf/raylet.proto`.
///
/// Every handler takes the request message, a reply message to fill in, and a
/// `send_reply_callback`. The implementation can handle a request asynchronously;
/// when handling is done, the `send_reply_callback` should be called.
class RayletServiceHandler {
 public:
  /// Virtual destructor so that implementations can be destroyed safely through a
  /// pointer or reference to this interface.
  virtual ~RayletServiceHandler() = default;

  /// Handle a `RegisterClient` request.
  ///
  /// \param[in] request The request message.
  /// \param[out] reply The reply message.
  /// \param[in] send_reply_callback The callback to be called when the request is done.
  virtual void HandleRegisterClientRequest(const RegisterClientRequest &request,
                                           RegisterClientReply *reply,
                                           SendReplyCallback send_reply_callback) = 0;

  /// Handle a `SubmitTask` request.
  /// The implementation can handle this request asynchronously. When handling is done,
  /// the `send_reply_callback` should be called.
  ///
  /// \param[in] request The request message.
  /// \param[out] reply The reply message.
  /// \param[in] send_reply_callback The callback to be called when the request is done.
  virtual void HandleSubmitTaskRequest(const SubmitTaskRequest &request,
                                       SubmitTaskReply *reply,
                                       SendReplyCallback send_reply_callback) = 0;

  /// Handle a `DisconnectClient` request.
  virtual void HandleDisconnectClientRequest(const DisconnectClientRequest &request,
                                             DisconnectClientReply *reply,
                                             SendReplyCallback send_reply_callback) = 0;

  /// Handle a `GetTask` request.
  virtual void HandleGetTaskRequest(const GetTaskRequest &request, GetTaskReply *reply,
                                    SendReplyCallback send_reply_callback) = 0;

  /// Handle a `TaskDone` request.
  virtual void HandleTaskDoneRequest(const TaskDoneRequest &request, TaskDoneReply *reply,
                                     SendReplyCallback send_reply_callback) = 0;

  /// Handle a `FetchOrReconstruct` request.
  virtual void HandleFetchOrReconstructRequest(const FetchOrReconstructRequest &request,
                                               FetchOrReconstructReply *reply,
                                               SendReplyCallback send_reply_callback) = 0;

  /// Handle a `NotifyUnblocked` request.
  virtual void HandleNotifyUnblockedRequest(const NotifyUnblockedRequest &request,
                                            NotifyUnblockedReply *reply,
                                            SendReplyCallback send_reply_callback) = 0;

  /// Handle a `Wait` request.
  virtual void HandleWaitRequest(const WaitRequest &request, WaitReply *reply,
                                 SendReplyCallback send_reply_callback) = 0;

  /// Handle a `PushError` request.
  virtual void HandlePushErrorRequest(const PushErrorRequest &request,
                                      PushErrorReply *reply,
                                      SendReplyCallback send_reply_callback) = 0;

  /// Handle a `PushProfileEvents` request.
  virtual void HandlePushProfileEventsRequest(const PushProfileEventsRequest &request,
                                              PushProfileEventsReply *reply,
                                              SendReplyCallback send_reply_callback) = 0;

  /// Handle a `FreeObjectsInStore` request.
  virtual void HandleFreeObjectsInStoreRequest(const FreeObjectsInStoreRequest &request,
                                               FreeObjectsInStoreReply *reply,
                                               SendReplyCallback send_reply_callback) = 0;

  /// Handle a `PrepareActorCheckpoint` request.
  virtual void HandlePrepareActorCheckpointRequest(
      const PrepareActorCheckpointRequest &request, PrepareActorCheckpointReply *reply,
      SendReplyCallback send_reply_callback) = 0;

  /// Handle a `NotifyActorResumedFromCheckpoint` request.
  virtual void HandleNotifyActorResumedFromCheckpointRequest(
      const NotifyActorResumedFromCheckpointRequest &request,
      NotifyActorResumedFromCheckpointReply *reply,
      SendReplyCallback send_reply_callback) = 0;

  /// Handle a `SetResource` request.
  virtual void HandleSetResourceRequest(const SetResourceRequest &request,
                                        SetResourceReply *reply,
                                        SendReplyCallback send_reply_callback) = 0;

  /// Handle a `Heartbeat` request.
  virtual void HandleHeartbeatRequest(const HeartbeatRequest &request,
                                      HeartbeatReply *reply,
                                      SendReplyCallback send_reply_callback) = 0;
};
/// The `GrpcService` for `RayletGrpcService`.
class RayletGrpcService : public GrpcService {
public:
/// Construct a `RayletGrpcService`.
///
/// \param[in] io_service Service used to handle incoming requests
/// \param[in] handler The service handler that actually handle the requests.
RayletGrpcService(boost::asio::io_service &io_service,
RayletServiceHandler &service_handler)
: GrpcService(io_service), service_handler_(service_handler){};
protected:
grpc::Service &GetGrpcService() override { return service_; }
void InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the factory for `RegisterClient` requests.
std::unique_ptr<ServerCallFactory> register_client_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
RegisterClientRequest, RegisterClientReply>(
service_, &RayletService::AsyncService::RequestRegisterClient,
service_handler_, &RayletServiceHandler::HandleRegisterClientRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(register_client_call_factory), 10);
// Initialize the factory for `SubmitTask` requests.
std::unique_ptr<ServerCallFactory> submit_task_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, SubmitTaskRequest,
SubmitTaskReply>(
service_, &RayletService::AsyncService::RequestSubmitTask, service_handler_,
&RayletServiceHandler::HandleSubmitTaskRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(submit_task_call_factory), 20);
// Initialize the factory for `DisconnectClient` requests.
std::unique_ptr<ServerCallFactory> disconnect_client_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
DisconnectClientRequest, DisconnectClientReply>(
service_, &RayletService::AsyncService::RequestDisconnectClient,
service_handler_, &RayletServiceHandler::HandleDisconnectClientRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(disconnect_client_call_factory), 10);
// Initialize the factory for `GetTask` requests.
std::unique_ptr<ServerCallFactory> get_task_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, GetTaskRequest,
GetTaskReply>(
service_, &RayletService::AsyncService::RequestGetTask, service_handler_,
&RayletServiceHandler::HandleGetTaskRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(get_task_call_factory), 20);
// Initialize the factory for `TaskDone` requests.
std::unique_ptr<ServerCallFactory> task_done_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, TaskDoneRequest,
TaskDoneReply>(
service_, &RayletService::AsyncService::RequestTaskDone, service_handler_,
&RayletServiceHandler::HandleTaskDoneRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(task_done_call_factory), 20);
// Initialize the factory for `FetchOrReconstruct` requests.
std::unique_ptr<ServerCallFactory> fetch_or_reconstruct_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
FetchOrReconstructRequest, FetchOrReconstructReply>(
service_, &RayletService::AsyncService::RequestFetchOrReconstruct,
service_handler_, &RayletServiceHandler::HandleFetchOrReconstructRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(fetch_or_reconstruct_call_factory), 10);
// Initialize the factory for `NotifyUnblocked` requests.
std::unique_ptr<ServerCallFactory> notify_unblocked_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
NotifyUnblockedRequest, NotifyUnblockedReply>(
service_, &RayletService::AsyncService::RequestNotifyUnblocked,
service_handler_, &RayletServiceHandler::HandleNotifyUnblockedRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(notify_unblocked_call_factory), 10);
// Initialize the factory for `Wait` requests.
std::unique_ptr<ServerCallFactory> wait_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, WaitRequest,
WaitReply>(
service_, &RayletService::AsyncService::RequestWait, service_handler_,
&RayletServiceHandler::HandleWaitRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(std::move(wait_call_factory),
20);
// Initialize the factory for `PushError` requests.
std::unique_ptr<ServerCallFactory> push_error_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, PushErrorRequest,
PushErrorReply>(
service_, &RayletService::AsyncService::RequestPushError, service_handler_,
&RayletServiceHandler::HandlePushErrorRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(push_error_call_factory), 10);
// Initialize the factory for `PushProfileEvents` requests.
std::unique_ptr<ServerCallFactory> push_profile_events_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
PushProfileEventsRequest, PushProfileEventsReply>(
service_, &RayletService::AsyncService::RequestPushProfileEvents,
service_handler_, &RayletServiceHandler::HandlePushProfileEventsRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(push_profile_events_call_factory), 10);
// Initialize the factory for `FreeObjectsInStore` requests.
std::unique_ptr<ServerCallFactory> free_objects_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
FreeObjectsInStoreRequest, FreeObjectsInStoreReply>(
service_, &RayletService::AsyncService::RequestFreeObjectsInStore,
service_handler_, &RayletServiceHandler::HandleFreeObjectsInStoreRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(free_objects_call_factory), 10);
// Initialize the factory for `PrepareActorCheckpoint` requests.
std::unique_ptr<ServerCallFactory> prepare_actor_checkpoint_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
PrepareActorCheckpointRequest,
PrepareActorCheckpointReply>(
service_, &RayletService::AsyncService::RequestPrepareActorCheckpoint,
service_handler_, &RayletServiceHandler::HandlePrepareActorCheckpointRequest,
cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(prepare_actor_checkpoint_call_factory), 10);
// Initialize the factory for `NotifyActorResumedFromCheckpoint` requests.
std::unique_ptr<ServerCallFactory> notify_actor_resumed_from_checkpoint_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
NotifyActorResumedFromCheckpointRequest,
NotifyActorResumedFromCheckpointReply>(
service_,
&RayletService::AsyncService::RequestNotifyActorResumedFromCheckpoint,
service_handler_,
&RayletServiceHandler::HandleNotifyActorResumedFromCheckpointRequest, cq,
main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(notify_actor_resumed_from_checkpoint_call_factory), 10);
// Initialize the factory for `SetResource` requests.
std::unique_ptr<ServerCallFactory> set_resource_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, SetResourceRequest,
SetResourceReply>(
service_, &RayletService::AsyncService::RequestSetResource, service_handler_,
&RayletServiceHandler::HandleSetResourceRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(set_resource_call_factory), 10);
// Initialize the factory for `Heartbeat` requests.
std::unique_ptr<ServerCallFactory> heartbeat_call_factory(
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, HeartbeatRequest,
HeartbeatReply>(
service_, &RayletService::AsyncService::RequestHeartbeat, service_handler_,
&RayletServiceHandler::HandleHeartbeatRequest, cq, main_service_));
server_call_factories_and_concurrencies->emplace_back(
std::move(heartbeat_call_factory), 10);
}
private:
/// The grpc async service object.
RayletService::AsyncService service_;
/// The service handler that actually handle the requests.
RayletServiceHandler &service_handler_;
};
} // namespace rpc
} // namespace ray
#endif