mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 10:28:01 +08:00
Revert raylet to worker GRPC communication back to asio (#5450)
This commit is contained in:
+12
-43
@@ -65,20 +65,6 @@ cc_proto_library(
|
||||
deps = [":object_manager_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "raylet_proto",
|
||||
srcs = ["src/ray/protobuf/raylet.proto"],
|
||||
deps = [
|
||||
":common_proto",
|
||||
":gcs_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_proto_library(
|
||||
name = "raylet_cc_proto",
|
||||
deps = [":raylet_proto"],
|
||||
)
|
||||
|
||||
proto_library(
|
||||
name = "worker_proto",
|
||||
srcs = ["src/ray/protobuf/worker.proto"],
|
||||
@@ -182,34 +168,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# Raylet gRPC lib.
|
||||
cc_grpc_library(
|
||||
name = "raylet_cc_grpc",
|
||||
srcs = [":raylet_proto"],
|
||||
grpc_only = True,
|
||||
deps = [":raylet_cc_proto"],
|
||||
)
|
||||
|
||||
# Raylet rpc server and client.
|
||||
cc_library(
|
||||
name = "raylet_rpc",
|
||||
srcs = glob([
|
||||
"src/ray/rpc/raylet/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/ray/rpc/raylet/*.h",
|
||||
"src/ray/raylet/*.h",
|
||||
]),
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":grpc_common_lib",
|
||||
":ray_common",
|
||||
":raylet_cc_grpc",
|
||||
"@boost//:asio",
|
||||
"@com_github_grpc_grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
# Worker gRPC lib.
|
||||
cc_grpc_library(
|
||||
name = "worker_cc_grpc",
|
||||
@@ -263,6 +221,7 @@ cc_library(
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":common_cc_proto",
|
||||
":node_manager_fbs",
|
||||
":ray_util",
|
||||
"@boost//:asio",
|
||||
"@com_github_grpc_grpc//:grpc++",
|
||||
@@ -361,11 +320,11 @@ cc_library(
|
||||
deps = [
|
||||
":common_cc_proto",
|
||||
":gcs",
|
||||
":node_manager_fbs",
|
||||
":node_manager_rpc",
|
||||
":object_manager",
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
":raylet_rpc",
|
||||
":stats_lib",
|
||||
":worker_rpc",
|
||||
"@boost//:asio",
|
||||
@@ -461,6 +420,7 @@ cc_test(
|
||||
srcs = ["src/ray/raylet/lineage_cache_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":node_manager_fbs",
|
||||
":raylet_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
@@ -471,6 +431,7 @@ cc_test(
|
||||
srcs = ["src/ray/raylet/reconstruction_policy_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":node_manager_fbs",
|
||||
":object_manager",
|
||||
":raylet_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
@@ -670,6 +631,7 @@ cc_library(
|
||||
deps = [
|
||||
":gcs_cc_proto",
|
||||
":hiredis",
|
||||
":node_manager_fbs",
|
||||
":node_manager_rpc",
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
@@ -725,6 +687,13 @@ flatbuffer_cc_library(
|
||||
out_prefix = "src/ray/common/",
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "node_manager_fbs",
|
||||
srcs = ["src/ray/raylet/format/node_manager.fbs"],
|
||||
flatc_args = FLATC_ARGS,
|
||||
out_prefix = "src/ray/raylet/format/",
|
||||
)
|
||||
|
||||
flatbuffer_cc_library(
|
||||
name = "object_manager_fbs",
|
||||
srcs = ["src/ray/object_manager/format/object_manager.fbs"],
|
||||
|
||||
@@ -24,7 +24,11 @@ public class BaseTest {
|
||||
// These files need to be deleted after each test case.
|
||||
filesToDelete = ImmutableList.of(
|
||||
new File(Ray.getRuntimeContext().getRayletSocketName()),
|
||||
new File(Ray.getRuntimeContext().getObjectStoreSocketName())
|
||||
new File(Ray.getRuntimeContext().getObjectStoreSocketName()),
|
||||
// TODO(pcm): This is a workaround for the issue described
|
||||
// in the PR description of https://github.com/ray-project/ray/pull/5450
|
||||
// and should be fixed properly.
|
||||
new File("/tmp/ray/test/raylet_socket")
|
||||
);
|
||||
// Make sure the files will be deleted even if the test doesn't exit gracefully.
|
||||
filesToDelete.forEach(File::deleteOnExit);
|
||||
|
||||
@@ -102,4 +102,3 @@ public class ResourcesManagementTest extends BaseTest {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -374,7 +374,7 @@ cdef class RayletClient:
|
||||
|
||||
@property
|
||||
def client_id(self):
|
||||
return ClientID(self.client.get().GetWorkerId().Binary())
|
||||
return ClientID(self.client.get().GetWorkerID().Binary())
|
||||
|
||||
@property
|
||||
def job_id(self):
|
||||
|
||||
@@ -88,16 +88,16 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
|
||||
|
||||
cdef extern from "ray/protobuf/common.pb.h" nogil:
|
||||
cdef cppclass CLanguage "ray::rpc::Language":
|
||||
cdef cppclass CLanguage "Language":
|
||||
pass
|
||||
|
||||
|
||||
# This is a workaround for C++ enum class since Cython has no corresponding
|
||||
# representation.
|
||||
cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc::Language" nogil:
|
||||
cdef CLanguage LANGUAGE_PYTHON "ray::rpc::Language::PYTHON"
|
||||
cdef CLanguage LANGUAGE_CPP "ray::rpc::Language::CPP"
|
||||
cdef CLanguage LANGUAGE_JAVA "ray::rpc::Language::JAVA"
|
||||
cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil:
|
||||
cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON"
|
||||
cdef CLanguage LANGUAGE_CPP "Language::CPP"
|
||||
cdef CLanguage LANGUAGE_JAVA "Language::JAVA"
|
||||
|
||||
|
||||
cdef extern from "ray/common/task/scheduling_resources.h" \
|
||||
|
||||
@@ -23,14 +23,14 @@ from ray.includes.task cimport CTaskSpec
|
||||
|
||||
|
||||
cdef extern from "ray/protobuf/gcs.pb.h" nogil:
|
||||
cdef cppclass GCSProfileEvent "ray::rpc::ProfileTableData::ProfileEvent":
|
||||
cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent":
|
||||
void set_event_type(const c_string &value)
|
||||
void set_start_time(double value)
|
||||
void set_end_time(double value)
|
||||
c_string set_extra_data(const c_string &value)
|
||||
GCSProfileEvent()
|
||||
|
||||
cdef cppclass GCSProfileTableData "ray::rpc::ProfileTableData":
|
||||
cdef cppclass GCSProfileTableData "ProfileTableData":
|
||||
void set_component_type(const c_string &value)
|
||||
void set_component_id(const c_string &value)
|
||||
void set_node_ip_address(const c_string &value)
|
||||
@@ -43,12 +43,13 @@ ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \
|
||||
ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair
|
||||
|
||||
|
||||
cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil:
|
||||
cdef cppclass CRayletClient "ray::rpc::RayletClient":
|
||||
cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
cdef cppclass CRayletClient "RayletClient":
|
||||
CRayletClient(const c_string &raylet_socket,
|
||||
const CWorkerID &worker_id,
|
||||
c_bool is_worker, const CJobID &job_id,
|
||||
const CLanguage &language)
|
||||
CRayStatus Disconnect()
|
||||
CRayStatus SubmitTask(const CTaskSpec &task_spec)
|
||||
CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec)
|
||||
CRayStatus TaskDone()
|
||||
@@ -72,8 +73,7 @@ cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil:
|
||||
const CActorID &actor_id, const CActorCheckpointID &checkpoint_id)
|
||||
CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id)
|
||||
CLanguage GetLanguage() const
|
||||
CWorkerID GetWorkerId() const
|
||||
CWorkerID GetWorkerID() const
|
||||
CJobID GetJobID() const
|
||||
c_bool IsWorker() const
|
||||
CRayStatus Disconnect()
|
||||
const ResourceMappingType &GetResourceIDs() const
|
||||
|
||||
@@ -82,17 +82,6 @@ inline std::vector<T> VectorFromProtobuf(
|
||||
return std::vector<T>(pb_repeated.begin(), pb_repeated.end());
|
||||
}
|
||||
|
||||
template <typename Message>
|
||||
using AddFunction = void (Message::*)(const ::std::string &value);
|
||||
/// Add a vector of type ID to protobuf message.
|
||||
template <typename ID, typename Message>
|
||||
inline void IdVectorToProtobuf(const std::vector<ID> &ids, Message &message,
|
||||
AddFunction<Message> add_func) {
|
||||
for (const auto &id : ids) {
|
||||
(message.*add_func)(id.Binary());
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a Protobuf `RepeatedField` to a vector of IDs.
|
||||
template <class ID>
|
||||
inline std::vector<ID> IdVectorFromProtobuf(
|
||||
|
||||
@@ -20,12 +20,8 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000)
|
||||
/// warning is logged that the handler is taking too long.
|
||||
RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100)
|
||||
|
||||
/// The duration between heartbeats. This value is used for both worker and raylet.
|
||||
/// The duration between heartbeats. These are sent by the raylet.
|
||||
RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100)
|
||||
/// Worker heartbeats also use `heartbeat_timeout_milliseconds` as timer timeout period.
|
||||
/// If a worker has not sent a heartbeat in the last `num_worker_heartbeats_timeout`
|
||||
/// heartbeat intervals, raylet will mark this worker as dead.
|
||||
RAY_CONFIG(int64_t, num_worker_heartbeats_timeout, 30)
|
||||
/// If a component has not sent a heartbeat in the last num_heartbeats_timeout
|
||||
/// heartbeat intervals, the raylet monitor process will report
|
||||
/// it as dead to the db_client table.
|
||||
@@ -156,9 +152,6 @@ RAY_CONFIG(int32_t, num_actor_checkpoints_to_keep, 20)
|
||||
/// Maximum number of ids in one batch to send to GCS to delete keys.
|
||||
RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000)
|
||||
|
||||
/// Number of times for a raylet client to retry to register.
|
||||
RAY_CONFIG(int, num_raylet_client_retry_times, 25)
|
||||
|
||||
/// When getting objects from object store, print a warning every this number of attempts.
|
||||
RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50)
|
||||
|
||||
|
||||
@@ -674,22 +674,37 @@ std::string ResourceIdSet::ToString() const {
|
||||
return return_string;
|
||||
}
|
||||
|
||||
std::vector<rpc::ResourceIdSetInfo> ResourceIdSet::ToProtobuf() const {
|
||||
std::vector<rpc::ResourceIdSetInfo> resources;
|
||||
std::vector<flatbuffers::Offset<protocol::ResourceIdSetInfo>> ResourceIdSet::ToFlatbuf(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const {
|
||||
std::vector<flatbuffers::Offset<protocol::ResourceIdSetInfo>> return_message;
|
||||
for (auto const &resource_pair : available_resources_) {
|
||||
rpc::ResourceIdSetInfo resource_id_set_info;
|
||||
resource_id_set_info.set_resource_name(resource_pair.first);
|
||||
std::vector<int64_t> resource_ids;
|
||||
std::vector<double> resource_fractions;
|
||||
for (auto whole_id : resource_pair.second.WholeIds()) {
|
||||
resource_id_set_info.add_resource_ids(whole_id);
|
||||
resource_id_set_info.add_resource_fractions(1);
|
||||
resource_ids.push_back(whole_id);
|
||||
resource_fractions.push_back(1);
|
||||
}
|
||||
|
||||
for (auto const &fractional_pair : resource_pair.second.FractionalIds()) {
|
||||
resource_id_set_info.add_resource_ids(fractional_pair.first);
|
||||
resource_id_set_info.add_resource_fractions(fractional_pair.second.ToDouble());
|
||||
resource_ids.push_back(fractional_pair.first);
|
||||
resource_fractions.push_back(fractional_pair.second.ToDouble());
|
||||
}
|
||||
resources.emplace_back(resource_id_set_info);
|
||||
|
||||
auto resource_id_set_message = protocol::CreateResourceIdSetInfo(
|
||||
fbb, fbb.CreateString(resource_pair.first), fbb.CreateVector(resource_ids),
|
||||
fbb.CreateVector(resource_fractions));
|
||||
|
||||
return_message.push_back(resource_id_set_message);
|
||||
}
|
||||
return resources;
|
||||
|
||||
return return_message;
|
||||
}
|
||||
|
||||
const std::string ResourceIdSet::Serialize() const {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto resource_id_set_flatbuf = ToFlatbuf(fbb);
|
||||
fbb.Finish(fbb.CreateVector(resource_id_set_flatbuf));
|
||||
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
|
||||
}
|
||||
|
||||
/// SchedulingResources class implementation
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "ray/protobuf/common.pb.h"
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -422,10 +422,18 @@ class ResourceIdSet {
|
||||
/// \return A human-readable string version of the object.
|
||||
std::string ToString() const;
|
||||
|
||||
/// \brief Convert this object to a vector of protobuf `ResourceIdSetInfo`s.
|
||||
/// \brief Serialize this object using flatbuffers.
|
||||
///
|
||||
/// \return A vector inclusing resource id set infos.
|
||||
std::vector<rpc::ResourceIdSetInfo> ToProtobuf() const;
|
||||
/// \param fbb A flatbuffer builder object.
|
||||
/// \return A flatbuffer serialized version of this object.
|
||||
std::vector<flatbuffers::Offset<ray::protocol::ResourceIdSetInfo>> ToFlatbuf(
|
||||
flatbuffers::FlatBufferBuilder &fbb) const;
|
||||
|
||||
/// \brief Serialize this object as a string.
|
||||
///
|
||||
/// \return A serialized string of this object.
|
||||
/// TODO(zhijunfu): this can be removed after raylet client is migrated to grpc.
|
||||
const std::string Serialize() const;
|
||||
|
||||
private:
|
||||
/// A mapping from resource name to a set of resource IDs for that resource.
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
#include "ray/common/task/task_common.h"
|
||||
#include "ray/common/task/task_execution_spec.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/protobuf/common.pb.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -8,12 +8,10 @@
|
||||
#include "ray/core_worker/task_execution.h"
|
||||
#include "ray/core_worker/task_interface.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
using rpc::RayletClient;
|
||||
|
||||
/// The root class that contains all the core and language-independent functionalities
|
||||
/// of the worker. This class is supposed to be used to implement app-language (Java,
|
||||
/// Python, etc) workers.
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
inline ray::RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) {
|
||||
inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) {
|
||||
return reinterpret_cast<ray::CoreWorker *>(nativeCoreWorkerPointer)->GetRayletClient();
|
||||
}
|
||||
|
||||
|
||||
@@ -11,8 +11,6 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
using rpc::RayletClient;
|
||||
|
||||
class CoreWorker;
|
||||
class CoreWorkerStoreProvider;
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
||||
@@ -8,12 +8,10 @@
|
||||
#include "ray/core_worker/common.h"
|
||||
#include "ray/core_worker/store_provider/local_plasma_provider.h"
|
||||
#include "ray/core_worker/store_provider/store_provider.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
using rpc::RayletClient;
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
/// The class provides implementations for accessing plasma store, which includes both
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
using rpc::RayletClient;
|
||||
|
||||
class CoreWorker;
|
||||
|
||||
/// Options of a non-actor-creation task.
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ray/core_worker/store_provider/local_plasma_provider.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "src/ray/protobuf/direct_actor.grpc.pb.h"
|
||||
#include "src/ray/protobuf/direct_actor.pb.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
@@ -5,13 +5,11 @@
|
||||
|
||||
#include "ray/core_worker/object_interface.h"
|
||||
#include "ray/core_worker/transport/transport.h"
|
||||
#include "ray/rpc/raylet/raylet_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
using rpc::RayletClient;
|
||||
|
||||
/// In raylet task submitter and receiver, a task is submitted to raylet, and possibly
|
||||
/// gets forwarded to another raylet on which node the task should be executed, and
|
||||
/// then a worker on that node gets this task and starts executing it.
|
||||
|
||||
@@ -769,8 +769,9 @@ void ObjectManager::SpreadFreeObjectsRequest(
|
||||
const std::vector<std::shared_ptr<rpc::ObjectManagerClient>> &rpc_clients) {
|
||||
// This code path should be called from node manager.
|
||||
rpc::FreeObjectsRequest free_objects_request;
|
||||
IdVectorToProtobuf<ObjectID, rpc::FreeObjectsRequest>(
|
||||
object_ids, free_objects_request, &rpc::FreeObjectsRequest::add_object_ids);
|
||||
for (const auto &e : object_ids) {
|
||||
free_objects_request.add_object_ids(e.Binary());
|
||||
}
|
||||
|
||||
for (auto &rpc_client : rpc_clients) {
|
||||
rpc_client->FreeObjects(free_objects_request, [](const Status &status,
|
||||
|
||||
@@ -11,18 +11,6 @@ enum Language {
|
||||
CPP = 2;
|
||||
}
|
||||
|
||||
// Resource id set info.
|
||||
message ResourceIdSetInfo {
|
||||
// The name of the resource.
|
||||
bytes resource_name = 1;
|
||||
// The resource IDs reserved for this worker.
|
||||
repeated uint64 resource_ids = 2;
|
||||
// The fraction of each resource ID that is reserved for this worker. Note
|
||||
// that the length of this list must be the same as the length of
|
||||
// resource_ids.
|
||||
repeated double resource_fractions = 3;
|
||||
}
|
||||
|
||||
// Type of a worker.
|
||||
enum WorkerType {
|
||||
WORKER = 0;
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
import "src/ray/protobuf/common.proto";
|
||||
import "src/ray/protobuf/gcs.proto";
|
||||
|
||||
/// NOTE(Joey Jiang) Every request defined in this file should have a `worker_id` field,
|
||||
/// which will be used in `NodeManager::PreprocessRequest`.
|
||||
|
||||
/// Service request and reply messages.
|
||||
message RegisterClientRequest {
|
||||
// The worker id.
|
||||
bytes worker_id = 1;
|
||||
// Indicates the client is a worker or a driver.
|
||||
bool is_worker = 2;
|
||||
// The process ID of this worker.
|
||||
uint32 worker_pid = 3;
|
||||
// The job ID.
|
||||
bytes job_id = 4;
|
||||
// Language of this worker.
|
||||
Language language = 5;
|
||||
// Port that this worker is listening on.
|
||||
// If port > 0, then worker will listen to this port and wait for
|
||||
// raylet to push tasks, instead of invoking GetTask().
|
||||
int32 port = 6;
|
||||
}
|
||||
message RegisterClientReply {
|
||||
repeated int32 gpu_ids = 1;
|
||||
}
|
||||
|
||||
message SubmitTaskRequest {
|
||||
bytes worker_id = 1;
|
||||
TaskSpec task_spec = 2;
|
||||
}
|
||||
message SubmitTaskReply {
|
||||
}
|
||||
|
||||
message DisconnectClientRequest {
|
||||
bytes worker_id = 1;
|
||||
}
|
||||
message DisconnectClientReply {
|
||||
}
|
||||
|
||||
message GetTaskRequest {
|
||||
bytes worker_id = 1;
|
||||
}
|
||||
message GetTaskReply {
|
||||
// A string of bytes representing the task specification.
|
||||
bytes task_spec = 1;
|
||||
// A list of the resources reserved for this worker.
|
||||
repeated ResourceIdSetInfo fractional_resource_ids = 2;
|
||||
}
|
||||
|
||||
message TaskDoneRequest {
|
||||
bytes worker_id = 1;
|
||||
}
|
||||
message TaskDoneReply {
|
||||
}
|
||||
|
||||
message FetchOrReconstructRequest {
|
||||
// The worker ID.
|
||||
bytes worker_id = 1;
|
||||
// List of object IDs of the objects that we want to reconstruct or fetch.
|
||||
repeated bytes object_ids = 2;
|
||||
// Indicates that we only want to fetch objects, not reconstruct them.
|
||||
bool fetch_only = 3;
|
||||
// The current task ID. If fetch_only is false, then this task is blocked.
|
||||
bytes task_id = 4;
|
||||
}
|
||||
message FetchOrReconstructReply {
|
||||
}
|
||||
|
||||
message NotifyUnblockedRequest {
|
||||
bytes worker_id = 1;
|
||||
// The current task ID. This task is no longer blocked.
|
||||
bytes task_id = 2;
|
||||
}
|
||||
message NotifyUnblockedReply {
|
||||
}
|
||||
|
||||
message WaitRequest {
|
||||
// The worker ID.
|
||||
bytes worker_id = 1;
|
||||
// List of object ids we'll be waiting on.
|
||||
repeated bytes object_ids = 2;
|
||||
// Number of objects expected to be returned, if available.
|
||||
uint64 num_ready_objects = 3;
|
||||
// Timeout in milliseconds.
|
||||
int64 timeout = 4;
|
||||
// Whether to wait until objects appear locally.
|
||||
bool wait_local = 5;
|
||||
// The current task ID. If there are less than num_ready_objects local, then
|
||||
// this task is blocked.
|
||||
bytes task_id = 6;
|
||||
}
|
||||
message WaitReply {
|
||||
// List of object ids found.
|
||||
repeated bytes found = 1;
|
||||
// List of object ids not found.
|
||||
repeated bytes remaining = 2;
|
||||
}
|
||||
|
||||
message PushErrorRequest {
|
||||
// The worker ID.
|
||||
bytes worker_id = 1;
|
||||
// The job id that the error is for.
|
||||
bytes job_id = 2;
|
||||
// The type of the error.
|
||||
bytes type = 3;
|
||||
// The error message.
|
||||
bytes error_message = 4;
|
||||
// The timestamp of the error message.
|
||||
double timestamp = 5;
|
||||
}
|
||||
message PushErrorReply {
|
||||
}
|
||||
|
||||
message PushProfileEventsRequest {
|
||||
bytes worker_id = 1;
|
||||
ProfileTableData profile_table_data = 2;
|
||||
}
|
||||
message PushProfileEventsReply {
|
||||
}
|
||||
|
||||
message FreeObjectsInStoreRequest {
|
||||
// The worker ID.
|
||||
bytes worker_id = 1;
|
||||
// Whether keep this request within the local object store
|
||||
// or send it to all of the object stores.
|
||||
bool local_only = 2;
|
||||
// Whether also delete objects' creating tasks from GCS.
|
||||
bool delete_creating_tasks = 3;
|
||||
// List of object ids to delete from the object store.
|
||||
repeated bytes object_ids = 4;
|
||||
}
|
||||
message FreeObjectsInStoreReply {
|
||||
}
|
||||
|
||||
message PrepareActorCheckpointRequest {
|
||||
bytes worker_id = 1;
|
||||
bytes actor_id = 2;
|
||||
}
|
||||
message PrepareActorCheckpointReply {
|
||||
bytes worker_id = 1;
|
||||
bytes checkpoint_id = 2;
|
||||
}
|
||||
|
||||
message NotifyActorResumedFromCheckpointRequest {
|
||||
bytes worker_id = 1;
|
||||
// ID of the actor that resumed.
|
||||
bytes actor_id = 2;
|
||||
// ID of the checkpoint from which the actor was resumed.
|
||||
bytes checkpoint_id = 3;
|
||||
}
|
||||
message NotifyActorResumedFromCheckpointReply {
|
||||
}
|
||||
|
||||
message SetResourceRequest {
|
||||
bytes worker_id = 1;
|
||||
// Name of the resource to be set.
|
||||
bytes resource_name = 2;
|
||||
// Capacity of the resource to be set.
|
||||
double capacity = 3;
|
||||
// Client ID where this resource will be set.
|
||||
bytes client_id = 4;
|
||||
}
|
||||
message SetResourceReply {
|
||||
}
|
||||
|
||||
message HeartbeatRequest {
|
||||
bytes worker_id = 1;
|
||||
bool is_worker = 2;
|
||||
}
|
||||
message HeartbeatReply {
|
||||
}
|
||||
|
||||
/// Worker-to-raylet RPC service interface.
|
||||
service RayletService {
|
||||
// Register a new worker to the raylet.
|
||||
rpc RegisterClient(RegisterClientRequest) returns (RegisterClientReply);
|
||||
// Submit a task to the raylet.
|
||||
rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply);
|
||||
// Disconnect this client from raylet gracefully.
|
||||
rpc DisconnectClient(DisconnectClientRequest) returns (DisconnectClientReply);
|
||||
// Get a new task from the raylet.
|
||||
rpc GetTask(GetTaskRequest) returns (GetTaskReply);
|
||||
// Notify the raylet that a task is finished.
|
||||
rpc TaskDone(TaskDoneRequest) returns (TaskDoneReply);
|
||||
// Reconstruct or fetch possibly lost objects.
|
||||
rpc FetchOrReconstruct(FetchOrReconstructRequest) returns (FetchOrReconstructReply);
|
||||
// For a worker that was blocked on some object(s), tell the raylet
|
||||
// that the worker is now unblocked.
|
||||
rpc NotifyUnblocked(NotifyUnblockedRequest) returns (NotifyUnblockedReply);
|
||||
// Wait for objects to be ready either from local or remote plasma stores.
|
||||
// The `WaitReply` contains the objects found and objects remaining.
|
||||
rpc Wait(WaitRequest) returns (WaitReply);
|
||||
// Push an error to the relevant driver.
|
||||
rpc PushError(PushErrorRequest) returns (PushErrorReply);
|
||||
// Push some profiling events to the GCS. When sending this message to the
|
||||
// node manager, the message itself is serialized as a ProfileTableData object.
|
||||
rpc PushProfileEvents(PushProfileEventsRequest) returns (PushProfileEventsReply);
|
||||
// Free the objects in plasma objects store.
|
||||
rpc FreeObjectsInStore(FreeObjectsInStoreRequest) returns (FreeObjectsInStoreReply);
|
||||
// Request raylet backend to prepare a checkpoint for an actor.
|
||||
rpc PrepareActorCheckpoint(PrepareActorCheckpointRequest)
|
||||
returns (PrepareActorCheckpointReply);
|
||||
// Notify raylet backend that an actor was resumed from a checkpoint.
|
||||
rpc NotifyActorResumedFromCheckpoint(NotifyActorResumedFromCheckpointRequest)
|
||||
returns (NotifyActorResumedFromCheckpointReply);
|
||||
// Set dynamic custom resource.
|
||||
rpc SetResource(SetResourceRequest) returns (SetResourceReply);
|
||||
// Send a heartbeat message to raylet.
|
||||
rpc Heartbeat(HeartbeatRequest) returns (HeartbeatReply);
|
||||
}
|
||||
@@ -8,7 +8,9 @@ message AssignTaskRequest {
|
||||
// The task to be pushed.
|
||||
Task task = 1;
|
||||
// A list of the resources reserved for this worker.
|
||||
repeated ResourceIdSetInfo resource_ids = 2;
|
||||
// TODO(zhijunfu): `resource_ids` is represented as
|
||||
// flatbutters-serialized bytes, will be moved to protobuf later.
|
||||
bytes resource_ids = 2;
|
||||
}
|
||||
|
||||
message AssignTaskReply {
|
||||
|
||||
@@ -0,0 +1,292 @@
|
||||
#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h"
|
||||
|
||||
#include <jni.h>
|
||||
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/core_worker/lib/java/jni_utils.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeInit
|
||||
* Signature: (Ljava/lang/String;[BZ[B)J
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
|
||||
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
|
||||
jbyteArray jobId) {
|
||||
const auto worker_id = JavaByteArrayToId<WorkerID>(env, workerId);
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
|
||||
auto raylet_client = new std::unique_ptr<RayletClient>(
|
||||
new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA));
|
||||
env->ReleaseStringUTFChars(sockName, nativeString);
|
||||
return reinterpret_cast<jlong>(raylet_client);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeSubmitTask
|
||||
* Signature: (J[BLjava/nio/ByteBuffer;II)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
jbyte *data = env->GetByteArrayElements(taskSpec, NULL);
|
||||
jsize size = env->GetArrayLength(taskSpec);
|
||||
ray::rpc::TaskSpec task_spec_message;
|
||||
task_spec_message.ParseFromArray(data, size);
|
||||
env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT);
|
||||
|
||||
ray::TaskSpecification task_spec(task_spec_message);
|
||||
auto status = raylet_client->SubmitTask(task_spec);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGetTask
|
||||
* Signature: (J)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
std::unique_ptr<ray::TaskSpecification> spec;
|
||||
auto status = raylet_client->GetTask(&spec);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Serialize the task spec and copy to Java byte array.
|
||||
auto task_data = spec->Serialize();
|
||||
|
||||
jbyteArray result = env->NewByteArray(task_data.size());
|
||||
if (result == nullptr) {
|
||||
return nullptr; /* out of memory error thrown */
|
||||
}
|
||||
|
||||
env->SetByteArrayRegion(result, 0, task_data.size(),
|
||||
reinterpret_cast<const jbyte *>(task_data.data()));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeDestroy
|
||||
* Signature: (J)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy(
|
||||
JNIEnv *env, jclass, jlong client) {
|
||||
auto raylet_client = reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = (*raylet_client)->Disconnect();
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
delete raylet_client;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeWaitObject
|
||||
* Signature: (J[[BIIZ[B)[Z
|
||||
*/
|
||||
JNIEXPORT jbooleanArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
|
||||
JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns,
|
||||
jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) {
|
||||
std::vector<ObjectID> object_ids;
|
||||
auto len = env->GetArrayLength(objectIds);
|
||||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
const auto current_task_id = JavaByteArrayToId<TaskID>(env, currentTaskId);
|
||||
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
|
||||
// Invoke wait.
|
||||
WaitResultPair result;
|
||||
auto status =
|
||||
raylet_client->Wait(object_ids, numReturns, timeoutMillis,
|
||||
static_cast<bool>(isWaitLocal), current_task_id, &result);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
|
||||
// Convert result to java object.
|
||||
jboolean put_value = true;
|
||||
jbooleanArray resultArray = env->NewBooleanArray(object_ids.size());
|
||||
for (uint i = 0; i < result.first.size(); ++i) {
|
||||
for (uint j = 0; j < object_ids.size(); ++j) {
|
||||
if (result.first[i] == object_ids[j]) {
|
||||
env->SetBooleanArrayRegion(resultArray, j, 1, &put_value);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
put_value = false;
|
||||
for (uint i = 0; i < result.second.size(); ++i) {
|
||||
for (uint j = 0; j < object_ids.size(); ++j) {
|
||||
if (result.second[i] == object_ids[j]) {
|
||||
env->SetBooleanArrayRegion(resultArray, j, 1, &put_value);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return resultArray;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateActorCreationTaskId
|
||||
* Signature: ([B[BI)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
|
||||
const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter);
|
||||
const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id);
|
||||
jbyteArray result = env->NewByteArray(actor_creation_task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(actor_creation_task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateActorTaskId
|
||||
* Signature: ([B[BI[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter, jbyteArray actorId) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const TaskID actor_task_id =
|
||||
ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id);
|
||||
|
||||
jbyteArray result = env->NewByteArray(actor_task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, actor_task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(actor_task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeGenerateNormalTaskId
|
||||
* Signature: ([B[BI)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
const auto job_id = JavaByteArrayToId<JobID>(env, jobId);
|
||||
const auto parent_task_id = JavaByteArrayToId<TaskID>(env, parentTaskId);
|
||||
const TaskID task_id =
|
||||
ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter);
|
||||
|
||||
jbyteArray result = env->NewByteArray(task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
}
|
||||
env->SetByteArrayRegion(result, 0, task_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(task_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeFreePlasmaObjects
|
||||
* Signature: (J[[BZZ)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects(
|
||||
JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly,
|
||||
jboolean deleteCreatingTasks) {
|
||||
std::vector<ObjectID> object_ids;
|
||||
auto len = env->GetArrayLength(objectIds);
|
||||
for (int i = 0; i < len; i++) {
|
||||
jbyteArray object_id_bytes =
|
||||
static_cast<jbyteArray>(env->GetObjectArrayElement(objectIds, i));
|
||||
const auto object_id = JavaByteArrayToId<ObjectID>(env, object_id_bytes);
|
||||
object_ids.push_back(object_id);
|
||||
env->DeleteLocalRef(object_id_bytes);
|
||||
}
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativePrepareCheckpoint
|
||||
* Signature: (J[B)[B
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass,
|
||||
jlong client,
|
||||
jbyteArray actorId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
ActorCheckpointID checkpoint_id;
|
||||
auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr);
|
||||
jbyteArray result = env->NewByteArray(checkpoint_id.Size());
|
||||
env->SetByteArrayRegion(result, 0, checkpoint_id.Size(),
|
||||
reinterpret_cast<const jbyte *>(checkpoint_id.Data()));
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeNotifyActorResumedFromCheckpoint
|
||||
* Signature: (J[B[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint(
|
||||
JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto actor_id = JavaByteArrayToId<ActorID>(env, actorId);
|
||||
const auto checkpoint_id = JavaByteArrayToId<ActorCheckpointID>(env, checkpointId);
|
||||
auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Class: org_ray_runtime_raylet_RayletClientImpl
|
||||
* Method: nativeSetResource
|
||||
* Signature: (JLjava/lang/String;D[B)V
|
||||
*/
|
||||
JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource(
|
||||
JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity,
|
||||
jbyteArray nodeId) {
|
||||
auto &raylet_client = *reinterpret_cast<std::unique_ptr<RayletClient> *>(client);
|
||||
const auto node_id = JavaByteArrayToId<ClientID>(env, nodeId);
|
||||
const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE);
|
||||
|
||||
auto status = raylet_client->SetResource(native_resource_name,
|
||||
static_cast<double>(capacity), node_id);
|
||||
env->ReleaseStringUTFChars(resourceName, native_resource_name);
|
||||
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0);
|
||||
}
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ray/common/task/task_execution_spec.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/common/task/task_util.h"
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/lineage_cache.h"
|
||||
#include "ray/util/test_util.h"
|
||||
|
||||
|
||||
+377
-356
File diff suppressed because it is too large
Load Diff
@@ -7,8 +7,6 @@
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/rpc/node_manager/node_manager_server.h"
|
||||
#include "ray/rpc/node_manager/node_manager_client.h"
|
||||
#include "ray/rpc/raylet/raylet_server.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/common/client_connection.h"
|
||||
#include "ray/common/task/task_common.h"
|
||||
@@ -66,8 +64,7 @@ struct NodeManagerConfig {
|
||||
std::string session_dir;
|
||||
};
|
||||
|
||||
class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
public rpc::RayletServiceHandler {
|
||||
class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
public:
|
||||
/// Create a node manager.
|
||||
///
|
||||
@@ -78,6 +75,23 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
std::shared_ptr<gcs::RedisGcsClient> gcs_client,
|
||||
std::shared_ptr<ObjectDirectoryInterface> object_directory_);
|
||||
|
||||
/// Process a new client connection.
|
||||
///
|
||||
/// \param client The client to process.
|
||||
/// \return Void.
|
||||
void ProcessNewClient(LocalClientConnection &client);
|
||||
|
||||
/// Process a message from a client. This method is responsible for
|
||||
/// explicitly listening for more messages from the client if the client is
|
||||
/// still alive.
|
||||
///
|
||||
/// \param client The client that sent the message.
|
||||
/// \param message_type The message type (e.g., a flatbuffer enum).
|
||||
/// \param message_data A pointer to the message data.
|
||||
/// \return Void.
|
||||
void ProcessClientMessage(const std::shared_ptr<LocalClientConnection> &client,
|
||||
int64_t message_type, const uint8_t *message_data);
|
||||
|
||||
/// Subscribe to the relevant GCS tables and set up handlers.
|
||||
///
|
||||
/// \return Status indicating whether this was done successfully or not.
|
||||
@@ -91,91 +105,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// Record metrics.
|
||||
void RecordMetrics() const;
|
||||
|
||||
public:
|
||||
/// Get the port of the node manager rpc server.
|
||||
int GetServerPort() const { return node_manager_server_.GetPort(); }
|
||||
|
||||
/// Preprocess request from raylet client. We will check whether the worker is being
|
||||
/// killed due to job finishing.
|
||||
///
|
||||
/// \param worker_id The worker id.
|
||||
/// \param request_name The request name.
|
||||
/// \return False if there is no need to process this request.
|
||||
bool PreprocessRequest(const WorkerID &worker_id, const std::string &request_name);
|
||||
|
||||
/// Implementation of node manager grpc service.
|
||||
|
||||
/// Handle a `ForwardTask` request.
|
||||
///
|
||||
/// \param request The request.
|
||||
/// \param reply The reply that will be sent to client.
|
||||
/// \param send_reply_callback Invoke this callback to send reply asynchronously.
|
||||
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
|
||||
rpc::ForwardTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
/// Implementation of raylet grpc service handlers, please see definitions
|
||||
/// in `src/ray/protobuf/raylet.proto` for more details.
|
||||
void HandleRegisterClientRequest(const rpc::RegisterClientRequest &request,
|
||||
rpc::RegisterClientReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request,
|
||||
rpc::SubmitTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleDisconnectClientRequest(const rpc::DisconnectClientRequest &request,
|
||||
rpc::DisconnectClientReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleGetTaskRequest(const rpc::GetTaskRequest &request, rpc::GetTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleTaskDoneRequest(const rpc::TaskDoneRequest &request,
|
||||
rpc::TaskDoneReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleFetchOrReconstructRequest(
|
||||
const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleNotifyUnblockedRequest(const rpc::NotifyUnblockedRequest &request,
|
||||
rpc::NotifyUnblockedReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleWaitRequest(const rpc::WaitRequest &request, rpc::WaitReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandlePushErrorRequest(const rpc::PushErrorRequest &request,
|
||||
rpc::PushErrorReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandlePushProfileEventsRequest(
|
||||
const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleFreeObjectsInStoreRequest(
|
||||
const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandlePrepareActorCheckpointRequest(
|
||||
const rpc::PrepareActorCheckpointRequest &request,
|
||||
rpc::PrepareActorCheckpointReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleNotifyActorResumedFromCheckpointRequest(
|
||||
const rpc::NotifyActorResumedFromCheckpointRequest &request,
|
||||
rpc::NotifyActorResumedFromCheckpointReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleSetResourceRequest(const rpc::SetResourceRequest &request,
|
||||
rpc::SetResourceReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
void HandleHeartbeatRequest(const rpc::HeartbeatRequest &request,
|
||||
rpc::HeartbeatReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
private:
|
||||
/// Methods for handling clients.
|
||||
|
||||
@@ -276,11 +208,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// \return Void.
|
||||
void SubmitTask(const Task &task, const Lineage &uncommitted_lineage,
|
||||
bool forwarded = false);
|
||||
/// Handle the case that a worker is available.
|
||||
///
|
||||
/// \param id Id of the worker.
|
||||
/// \return Void.
|
||||
void HandleWorkerAvailable(const WorkerID &worker_id);
|
||||
/// Assign a task. The task is assumed to not be queued in local_queues_.
|
||||
///
|
||||
/// \param task The task in question.
|
||||
@@ -392,7 +319,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// \param ray_get Whether the task is blocked in a `ray.get` call, as
|
||||
/// opposed to a `ray.wait` call.
|
||||
/// \return Void.
|
||||
void HandleTaskBlocked(const WorkerID &worker_id,
|
||||
void HandleTaskBlocked(const std::shared_ptr<LocalClientConnection> &client,
|
||||
const std::vector<ObjectID> &required_object_ids,
|
||||
const TaskID ¤t_task_id, bool ray_get);
|
||||
|
||||
@@ -405,7 +332,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// \param client The client that is executing the unblocked task.
|
||||
/// \param current_task_id The task that is unblocked.
|
||||
/// \return Void.
|
||||
void HandleTaskUnblocked(const WorkerID &worker_id, const TaskID ¤t_task_id);
|
||||
void HandleTaskUnblocked(const std::shared_ptr<LocalClientConnection> &client,
|
||||
const TaskID ¤t_task_id);
|
||||
|
||||
/// Kill a worker.
|
||||
///
|
||||
@@ -481,8 +409,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// \param client The client that sent the message.
|
||||
/// \param intentional_disconnect Whether the client was intentionally disconnected.
|
||||
/// \return Void.
|
||||
void ProcessDisconnectClientMessage(const WorkerID &worker_id,
|
||||
bool intentional_disconnect = false);
|
||||
void ProcessDisconnectClientMessage(
|
||||
const std::shared_ptr<LocalClientConnection> &client,
|
||||
bool intentional_disconnect = false);
|
||||
|
||||
/// Process client message of SubmitTask
|
||||
///
|
||||
@@ -550,6 +479,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
|
||||
bool intentional_disconnect);
|
||||
|
||||
/// Finish assigning a task to a worker.
|
||||
///
|
||||
/// \param task_id Id of the task.
|
||||
/// \param worker Worker which the task is assigned to.
|
||||
/// \param success Whether the task is successfully assigned to the worker.
|
||||
/// \return void.
|
||||
void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success);
|
||||
|
||||
/// Handle a `ForwardTask` request.
|
||||
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
|
||||
rpc::ForwardTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) override;
|
||||
|
||||
// GCS client ID for this node.
|
||||
ClientID client_id_;
|
||||
boost::asio::io_service &io_service_;
|
||||
@@ -609,9 +551,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler,
|
||||
/// The node manager RPC service.
|
||||
rpc::NodeManagerGrpcService node_manager_service_;
|
||||
|
||||
/// The raylet RPC service.
|
||||
rpc::RayletGrpcService raylet_service_;
|
||||
|
||||
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s
|
||||
/// as well as all `WorkerTaskClient`s.
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
|
||||
@@ -7,6 +7,34 @@
|
||||
|
||||
#include "ray/common/status.h"
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<std::string> GenerateEnumNames(const char *const *enum_names_ptr,
|
||||
int start_index, int end_index) {
|
||||
std::vector<std::string> enum_names;
|
||||
for (int i = 0; i < start_index; ++i) {
|
||||
enum_names.push_back("EmptyMessageType");
|
||||
}
|
||||
size_t i = 0;
|
||||
while (true) {
|
||||
const char *name = enum_names_ptr[i];
|
||||
if (name == nullptr) {
|
||||
break;
|
||||
}
|
||||
enum_names.push_back(name);
|
||||
i++;
|
||||
}
|
||||
RAY_CHECK(static_cast<size_t>(end_index) == enum_names.size() - 1)
|
||||
<< "Message Type mismatch!";
|
||||
return enum_names;
|
||||
}
|
||||
|
||||
static const std::vector<std::string> node_manager_message_enum =
|
||||
GenerateEnumNames(ray::protocol::EnumNamesMessageType(),
|
||||
static_cast<int>(ray::protocol::MessageType::MIN),
|
||||
static_cast<int>(ray::protocol::MessageType::MAX));
|
||||
} // namespace
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
@@ -23,10 +51,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_
|
||||
node_manager_(main_service, node_manager_config, object_manager_, gcs_client_,
|
||||
object_directory_),
|
||||
socket_name_(socket_name),
|
||||
raylet_server_("Raylet", socket_name),
|
||||
raylet_service_(main_service, node_manager_) {
|
||||
raylet_server_.RegisterService(raylet_service_);
|
||||
raylet_server_.Run();
|
||||
acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)),
|
||||
socket_(main_service) {
|
||||
// Start listening for clients.
|
||||
DoAccept();
|
||||
|
||||
RAY_CHECK_OK(RegisterGcs(
|
||||
node_ip_address, socket_name_, object_manager_config.store_socket_name,
|
||||
@@ -80,6 +108,31 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Raylet::DoAccept() {
|
||||
acceptor_.async_accept(socket_, boost::bind(&Raylet::HandleAccept, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void Raylet::HandleAccept(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
// TODO: typedef these handlers.
|
||||
ClientHandler<boost::asio::local::stream_protocol> client_handler =
|
||||
[this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); };
|
||||
MessageHandler<boost::asio::local::stream_protocol> message_handler =
|
||||
[this](std::shared_ptr<LocalClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
node_manager_.ProcessClientMessage(client, message_type, message);
|
||||
};
|
||||
// Accept a new local client and dispatch it to the node manager.
|
||||
auto new_connection = LocalClientConnection::Create(
|
||||
client_handler, message_handler, std::move(socket_), "worker",
|
||||
node_manager_message_enum,
|
||||
static_cast<int64_t>(protocol::MessageType::DisconnectClient));
|
||||
}
|
||||
// We're ready to accept another client.
|
||||
DoAccept();
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -75,12 +75,10 @@ class Raylet {
|
||||
/// The name of the socket this raylet listens on.
|
||||
std::string socket_name_;
|
||||
|
||||
/// The gRPC server, listens for local raylet client connections through a unix domain
|
||||
/// socket.
|
||||
rpc::GrpcServer raylet_server_;
|
||||
|
||||
/// The gRPC service.
|
||||
rpc::RayletGrpcService raylet_service_;
|
||||
/// An acceptor for new clients.
|
||||
boost::asio::local::stream_protocol::acceptor acceptor_;
|
||||
/// The socket to listen on for new clients.
|
||||
boost::asio::local::stream_protocol::socket socket_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -0,0 +1,392 @@
|
||||
#include "raylet_client.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
#include <netdb.h>
|
||||
#include <netinet/in.h>
|
||||
#include <stdarg.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
#include <sys/ioctl.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <sys/un.h>
|
||||
|
||||
#include "ray/common/common_protocol.h"
|
||||
#include "ray/common/ray_config.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
using MessageType = ray::protocol::MessageType;
|
||||
|
||||
// TODO(rkn): The io methods below should be removed.
|
||||
int connect_ipc_sock(const std::string &socket_pathname) {
|
||||
struct sockaddr_un socket_address;
|
||||
int socket_fd;
|
||||
|
||||
socket_fd = socket(AF_UNIX, SOCK_STREAM, 0);
|
||||
if (socket_fd < 0) {
|
||||
RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname;
|
||||
return -1;
|
||||
}
|
||||
|
||||
memset(&socket_address, 0, sizeof(socket_address));
|
||||
socket_address.sun_family = AF_UNIX;
|
||||
if (socket_pathname.length() + 1 > sizeof(socket_address.sun_path)) {
|
||||
RAY_LOG(ERROR) << "Socket pathname is too long.";
|
||||
close(socket_fd);
|
||||
return -1;
|
||||
}
|
||||
strncpy(socket_address.sun_path, socket_pathname.c_str(), socket_pathname.length() + 1);
|
||||
|
||||
if (connect(socket_fd, (struct sockaddr *)&socket_address, sizeof(socket_address)) !=
|
||||
0) {
|
||||
close(socket_fd);
|
||||
return -1;
|
||||
}
|
||||
return socket_fd;
|
||||
}
|
||||
|
||||
int read_bytes(int socket_fd, uint8_t *cursor, size_t length) {
|
||||
ssize_t nbytes = 0;
|
||||
// Termination condition: EOF or read 'length' bytes total.
|
||||
size_t bytesleft = length;
|
||||
size_t offset = 0;
|
||||
while (bytesleft > 0) {
|
||||
nbytes = read(socket_fd, cursor + offset, bytesleft);
|
||||
if (nbytes < 0) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
return -1; // Errno will be set.
|
||||
} else if (0 == nbytes) {
|
||||
// Encountered early EOF.
|
||||
return -1;
|
||||
}
|
||||
RAY_CHECK(nbytes > 0);
|
||||
bytesleft -= nbytes;
|
||||
offset += nbytes;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int write_bytes(int socket_fd, uint8_t *cursor, size_t length) {
|
||||
ssize_t nbytes = 0;
|
||||
size_t bytesleft = length;
|
||||
size_t offset = 0;
|
||||
while (bytesleft > 0) {
|
||||
// While we haven't written the whole message, write to the file
|
||||
// descriptor, advance the cursor, and decrease the amount left to write.
|
||||
nbytes = write(socket_fd, cursor + offset, bytesleft);
|
||||
if (nbytes < 0) {
|
||||
if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) {
|
||||
continue;
|
||||
}
|
||||
return -1; // Errno will be set.
|
||||
} else if (0 == nbytes) {
|
||||
// Encountered early EOF.
|
||||
return -1;
|
||||
}
|
||||
RAY_CHECK(nbytes > 0);
|
||||
bytesleft -= nbytes;
|
||||
offset += nbytes;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries,
|
||||
int64_t timeout) {
|
||||
// Pick the default values if the user did not specify.
|
||||
if (num_retries < 0) {
|
||||
num_retries = RayConfig::instance().num_connect_attempts();
|
||||
}
|
||||
if (timeout < 0) {
|
||||
timeout = RayConfig::instance().connect_timeout_milliseconds();
|
||||
}
|
||||
RAY_CHECK(!raylet_socket.empty());
|
||||
conn_ = -1;
|
||||
for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) {
|
||||
conn_ = connect_ipc_sock(raylet_socket);
|
||||
if (conn_ >= 0) break;
|
||||
if (num_attempts > 0) {
|
||||
RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << raylet_socket
|
||||
<< " (num_attempts = " << num_attempts
|
||||
<< ", num_retries = " << num_retries << ")";
|
||||
}
|
||||
// Sleep for timeout milliseconds.
|
||||
usleep(timeout * 1000);
|
||||
}
|
||||
// If we could not connect to the socket, exit.
|
||||
if (conn_ == -1) {
|
||||
RAY_LOG(FATAL) << "Could not connect to socket " << raylet_socket;
|
||||
}
|
||||
}
|
||||
|
||||
ray::Status RayletConnection::Disconnect() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateDisconnectClient(fbb);
|
||||
fbb.Finish(message);
|
||||
auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb);
|
||||
// Don't be too strict for disconnection errors.
|
||||
// Just create logs and prevent it from crash.
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << status.ToString()
|
||||
<< " [RayletClient] Failed to disconnect from raylet.";
|
||||
}
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletConnection::ReadMessage(MessageType type,
|
||||
std::unique_ptr<uint8_t[]> &message) {
|
||||
int64_t cookie;
|
||||
int64_t type_field;
|
||||
int64_t length;
|
||||
int closed = read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie));
|
||||
if (closed) goto disconnected;
|
||||
RAY_CHECK(cookie == RayConfig::instance().ray_cookie());
|
||||
closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field));
|
||||
if (closed) goto disconnected;
|
||||
closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length));
|
||||
if (closed) goto disconnected;
|
||||
message = std::unique_ptr<uint8_t[]>(new uint8_t[length]);
|
||||
closed = read_bytes(conn_, message.get(), length);
|
||||
if (closed) {
|
||||
// Handle the case in which the socket is closed.
|
||||
message.reset(nullptr);
|
||||
disconnected:
|
||||
message = nullptr;
|
||||
type_field = static_cast<int64_t>(MessageType::DisconnectClient);
|
||||
length = 0;
|
||||
}
|
||||
if (type_field == static_cast<int64_t>(MessageType::DisconnectClient)) {
|
||||
return ray::Status::IOError("[RayletClient] Raylet connection closed.");
|
||||
}
|
||||
if (type_field != static_cast<int64_t>(type)) {
|
||||
return ray::Status::TypeError(
|
||||
std::string("[RayletClient] Raylet connection corrupted. ") +
|
||||
"Expected message type: " + std::to_string(static_cast<int64_t>(type)) +
|
||||
"; got message type: " + std::to_string(type_field) +
|
||||
". Check logs or dmesg for previous errors.");
|
||||
}
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletConnection::WriteMessage(MessageType type,
|
||||
flatbuffers::FlatBufferBuilder *fbb) {
|
||||
std::unique_lock<std::mutex> guard(write_mutex_);
|
||||
int64_t cookie = RayConfig::instance().ray_cookie();
|
||||
int64_t length = fbb ? fbb->GetSize() : 0;
|
||||
uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr;
|
||||
int64_t type_field = static_cast<int64_t>(type);
|
||||
auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly.");
|
||||
int closed;
|
||||
closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie));
|
||||
if (closed) return io_error;
|
||||
closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field));
|
||||
if (closed) return io_error;
|
||||
closed = write_bytes(conn_, (uint8_t *)&length, sizeof(length));
|
||||
if (closed) return io_error;
|
||||
closed = write_bytes(conn_, bytes, length * sizeof(char));
|
||||
if (closed) return io_error;
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletConnection::AtomicRequestReply(
|
||||
MessageType request_type, MessageType reply_type,
|
||||
std::unique_ptr<uint8_t[]> &reply_message, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
auto status = WriteMessage(request_type, fbb);
|
||||
if (!status.ok()) return status;
|
||||
return ReadMessage(reply_type, reply_message);
|
||||
}
|
||||
|
||||
RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
|
||||
bool is_worker, const JobID &job_id, const Language &language,
|
||||
int port)
|
||||
: worker_id_(worker_id), is_worker_(is_worker), job_id_(job_id), language_(language) {
|
||||
// For C++14, we could use std::make_unique
|
||||
conn_ = std::unique_ptr<RayletConnection>(new RayletConnection(raylet_socket, -1, -1));
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateRegisterClientRequest(
|
||||
fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id),
|
||||
language, port);
|
||||
fbb.Finish(message);
|
||||
// Register the process ID with the raylet.
|
||||
// NOTE(swang): If raylet exits and we are registered as a worker, we will get killed.
|
||||
auto status = conn_->WriteMessage(MessageType::RegisterClientRequest, &fbb);
|
||||
RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet.");
|
||||
}
|
||||
|
||||
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateSubmitTaskRequest(
|
||||
fbb, fbb.CreateString(task_spec.Serialize()));
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::SubmitTask, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::GetTask(std::unique_ptr<ray::TaskSpecification> *task_spec) {
|
||||
std::unique_ptr<uint8_t[]> reply;
|
||||
// Receive a task from the raylet. This will block until the raylet
|
||||
// gives this client a task.
|
||||
auto status =
|
||||
conn_->AtomicRequestReply(MessageType::GetTask, MessageType::ExecuteTask, reply);
|
||||
if (!status.ok()) return status;
|
||||
// Parse the flatbuffer object.
|
||||
auto reply_message = flatbuffers::GetRoot<ray::protocol::GetTaskReply>(reply.get());
|
||||
// Set the resource IDs for this task.
|
||||
resource_ids_.clear();
|
||||
for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); ++i) {
|
||||
auto const &fractional_resource_ids =
|
||||
reply_message->fractional_resource_ids()->Get(i);
|
||||
auto &acquired_resources =
|
||||
resource_ids_[string_from_flatbuf(*fractional_resource_ids->resource_name())];
|
||||
|
||||
size_t num_resource_ids = fractional_resource_ids->resource_ids()->size();
|
||||
size_t num_resource_fractions = fractional_resource_ids->resource_fractions()->size();
|
||||
RAY_CHECK(num_resource_ids == num_resource_fractions);
|
||||
RAY_CHECK(num_resource_ids > 0);
|
||||
for (size_t j = 0; j < num_resource_ids; ++j) {
|
||||
int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j);
|
||||
double resource_fraction = fractional_resource_ids->resource_fractions()->Get(j);
|
||||
if (num_resource_ids > 1) {
|
||||
int64_t whole_fraction = resource_fraction;
|
||||
RAY_CHECK(whole_fraction == resource_fraction);
|
||||
}
|
||||
acquired_resources.push_back(std::make_pair(resource_id, resource_fraction));
|
||||
}
|
||||
}
|
||||
|
||||
// Return the copy of the task spec and pass ownership to the caller.
|
||||
task_spec->reset(
|
||||
new ray::TaskSpecification(string_from_flatbuf(*reply_message->task_spec())));
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::TaskDone() {
|
||||
return conn_->WriteMessage(MessageType::TaskDone);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
|
||||
bool fetch_only,
|
||||
const TaskID ¤t_task_id) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto object_ids_message = to_flatbuf(fbb, object_ids);
|
||||
auto message = ray::protocol::CreateFetchOrReconstruct(
|
||||
fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id));
|
||||
fbb.Finish(message);
|
||||
auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb);
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message =
|
||||
ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id));
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
|
||||
int64_t timeout_milliseconds, bool wait_local,
|
||||
const TaskID ¤t_task_id, WaitResultPair *result) {
|
||||
// Write request.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateWaitRequest(
|
||||
fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local,
|
||||
to_flatbuf(fbb, current_task_id));
|
||||
fbb.Finish(message);
|
||||
std::unique_ptr<uint8_t[]> reply;
|
||||
auto status = conn_->AtomicRequestReply(MessageType::WaitRequest,
|
||||
MessageType::WaitReply, reply, &fbb);
|
||||
if (!status.ok()) return status;
|
||||
// Parse the flatbuffer object.
|
||||
auto reply_message = flatbuffers::GetRoot<ray::protocol::WaitReply>(reply.get());
|
||||
auto found = reply_message->found();
|
||||
for (uint i = 0; i < found->size(); i++) {
|
||||
ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str());
|
||||
result->first.push_back(object_id);
|
||||
}
|
||||
auto remaining = reply_message->remaining();
|
||||
for (uint i = 0; i < remaining->size(); i++) {
|
||||
ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str());
|
||||
result->second.push_back(object_id);
|
||||
}
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreatePushErrorRequest(
|
||||
fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
|
||||
fbb.CreateString(error_message), timestamp);
|
||||
fbb.Finish(message);
|
||||
|
||||
return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = fbb.CreateString(profile_events.SerializeAsString());
|
||||
fbb.Finish(message);
|
||||
|
||||
auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb);
|
||||
// Don't be too strict for profile errors. Just create logs and prevent it from crash.
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << status.ToString()
|
||||
<< " [RayletClient] Failed to push profile events.";
|
||||
}
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_ids,
|
||||
bool local_only, bool delete_creating_tasks) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateFreeObjectsRequest(
|
||||
fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids));
|
||||
fbb.Finish(message);
|
||||
|
||||
auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb);
|
||||
return status;
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
|
||||
ActorCheckpointID &checkpoint_id) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message =
|
||||
ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id));
|
||||
fbb.Finish(message);
|
||||
|
||||
std::unique_ptr<uint8_t[]> reply;
|
||||
auto status =
|
||||
conn_->AtomicRequestReply(MessageType::PrepareActorCheckpointRequest,
|
||||
MessageType::PrepareActorCheckpointReply, reply, &fbb);
|
||||
if (!status.ok()) return status;
|
||||
auto reply_message =
|
||||
flatbuffers::GetRoot<ray::protocol::PrepareActorCheckpointReply>(reply.get());
|
||||
checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str());
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
|
||||
const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint(
|
||||
fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id));
|
||||
fbb.Finish(message);
|
||||
|
||||
return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::SetResource(const std::string &resource_name,
|
||||
const double capacity,
|
||||
const ray::ClientID &client_Id) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateSetResourceRequest(
|
||||
fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id));
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb);
|
||||
}
|
||||
@@ -1,26 +1,18 @@
|
||||
#ifndef RAY_RPC_RAYLET_CLIENT_H
|
||||
#define RAY_RPC_RAYLET_CLIENT_H
|
||||
#ifndef RAYLET_CLIENT_H
|
||||
#define RAYLET_CLIENT_H
|
||||
|
||||
#include <ray/common/ray_config.h>
|
||||
#include <ray/protobuf/gcs.pb.h>
|
||||
#include <unistd.h>
|
||||
#include <future>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "src/ray/common/status.h"
|
||||
#include "src/ray/protobuf/raylet.grpc.pb.h"
|
||||
#include "src/ray/protobuf/raylet.pb.h"
|
||||
#include "src/ray/rpc/client_call.h"
|
||||
|
||||
using ray::ActorCheckpointID;
|
||||
using ray::ActorID;
|
||||
using ray::ClientID;
|
||||
using ray::JobID;
|
||||
using ray::ObjectID;
|
||||
using ray::TaskID;
|
||||
@@ -28,39 +20,65 @@ using ray::WorkerID;
|
||||
|
||||
using ray::Language;
|
||||
using ray::rpc::ProfileTableData;
|
||||
using WaitResultPair = std::pair<std::vector<ObjectID>, std::vector<ObjectID>>;
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
using MessageType = ray::protocol::MessageType;
|
||||
using ResourceMappingType =
|
||||
std::unordered_map<std::string, std::vector<std::pair<int64_t, double>>>;
|
||||
using WaitResultPair = std::pair<std::vector<ObjectID>, std::vector<ObjectID>>;
|
||||
|
||||
class RayletConnection {
|
||||
public:
|
||||
/// Connect to the raylet.
|
||||
///
|
||||
/// \param raylet_socket The name of the socket to use to connect to the raylet.
|
||||
/// \param worker_id A unique ID to represent the worker.
|
||||
/// \param is_worker Whether this client is a worker. If it is a worker, an
|
||||
/// additional message will be sent to register as one.
|
||||
/// \param job_id The ID of the driver. This is non-nil if the client is a
|
||||
/// driver.
|
||||
/// \return The connection information.
|
||||
RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout);
|
||||
|
||||
~RayletConnection() { close(conn_); }
|
||||
/// Notify the raylet that this client is disconnecting gracefully. This
|
||||
/// is used by actors to exit gracefully so that the raylet doesn't
|
||||
/// propagate an error message to the driver.
|
||||
///
|
||||
/// \return ray::Status.
|
||||
ray::Status Disconnect();
|
||||
ray::Status ReadMessage(MessageType type, std::unique_ptr<uint8_t[]> &message);
|
||||
ray::Status WriteMessage(MessageType type,
|
||||
flatbuffers::FlatBufferBuilder *fbb = nullptr);
|
||||
ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type,
|
||||
std::unique_ptr<uint8_t[]> &reply_message,
|
||||
flatbuffers::FlatBufferBuilder *fbb = nullptr);
|
||||
|
||||
private:
|
||||
/// File descriptor of the Unix domain socket that connects to raylet.
|
||||
int conn_;
|
||||
/// A mutex to protect stateful operations of the raylet client.
|
||||
std::mutex mutex_;
|
||||
/// A mutex to protect write operations of the raylet client.
|
||||
std::mutex write_mutex_;
|
||||
};
|
||||
|
||||
/// Client used for communicating with the raylet.
|
||||
class RayletClient {
|
||||
public:
|
||||
/// Constructor for the raylet client.
|
||||
/// TODO(jzh): At present, client call manager and reply handler service are generated
|
||||
/// in raylet client. Instead, we should add parameters to the constructor and pass them
|
||||
/// in. Change them as input parameters once we changed the worker into a server.
|
||||
/// Connect to the raylet.
|
||||
///
|
||||
/// \param[in] raylet_socket Unix domain socket of the raylet server.
|
||||
/// \param[in] worker_id The worker id.
|
||||
/// \param[in] is_worker Indicates whether a worker or a driver.
|
||||
/// \param[in] job_id The job id that this raylet client belongs to.
|
||||
/// \param[in] language The language type, python or java.
|
||||
/// \param[in] port The listening port of the worker server, -1 means that the worker
|
||||
/// does not have a server.
|
||||
/// \param raylet_socket The name of the socket to use to connect to the raylet.
|
||||
/// \param worker_id A unique ID to represent the worker.
|
||||
/// \param is_worker Whether this client is a worker. If it is a worker, an
|
||||
/// additional message will be sent to register as one.
|
||||
/// \param job_id The ID of the driver. This is non-nil if the client is a driver.
|
||||
/// \return The connection information.
|
||||
RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
|
||||
bool is_worker, const JobID &job_id, const Language &language,
|
||||
int port = -1);
|
||||
|
||||
~RayletClient();
|
||||
ray::Status Disconnect() { return conn_->Disconnect(); };
|
||||
|
||||
/// Send disconnect request to local raylet.
|
||||
ray::Status Disconnect();
|
||||
|
||||
/// Submit a task to the local raylet.
|
||||
/// Submit a task using the raylet code path.
|
||||
///
|
||||
/// \param The task specification.
|
||||
/// \return ray::Status.
|
||||
@@ -160,7 +178,7 @@ class RayletClient {
|
||||
|
||||
Language GetLanguage() const { return language_; }
|
||||
|
||||
WorkerID GetWorkerId() const { return worker_id_; }
|
||||
WorkerID GetWorkerID() const { return worker_id_; }
|
||||
|
||||
JobID GetJobID() const { return job_id_; }
|
||||
|
||||
@@ -169,53 +187,16 @@ class RayletClient {
|
||||
const ResourceMappingType &GetResourceIDs() const { return resource_ids_; }
|
||||
|
||||
private:
|
||||
/// Try to register client in raylet, we would retry serveral time to
|
||||
/// reconnect if failed. We need this because raylet client may start before raylet
|
||||
/// server.
|
||||
///
|
||||
/// \param times Number of times to retry.
|
||||
void TryRegisterClient(int retry_times);
|
||||
|
||||
ray::Status RegisterClient();
|
||||
|
||||
/// Send heartbeat requests to the raylet server.
|
||||
void Heartbeat();
|
||||
/// Id of the worker to which this raylet client belongs.
|
||||
const WorkerID worker_id_;
|
||||
/// Indicates whether this worker is a driver worker.
|
||||
/// Driver is treated as a special worker.
|
||||
const bool is_worker_;
|
||||
const JobID job_id_;
|
||||
const Language language_;
|
||||
const int port_;
|
||||
/// A map from resource name to the resource IDs that are currently reserved
|
||||
/// for this worker. Each pair consists of the resource ID and the fraction
|
||||
/// of that resource allocated for this worker.
|
||||
ResourceMappingType resource_ids_;
|
||||
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<RayletService::Stub> stub_;
|
||||
|
||||
/// Service for handling reply.
|
||||
boost::asio::io_service main_service_;
|
||||
|
||||
/// Asio work for main service.
|
||||
boost::asio::io_service::work work_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager client_call_manager_;
|
||||
|
||||
/// The thread used to handle reply.
|
||||
std::thread rpc_thread_;
|
||||
|
||||
/// Heartbeat timer.
|
||||
boost::asio::deadline_timer heartbeat_timer_;
|
||||
|
||||
/// Indicates whether the connection has been closed.
|
||||
bool is_connected_;
|
||||
/// The connection to the raylet server.
|
||||
std::unique_ptr<RayletConnection> conn_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RPC_RAYLET_CLIENT_H
|
||||
#endif
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/reconstruction_policy.h"
|
||||
|
||||
#include "ray/object_manager/object_directory.h"
|
||||
@@ -417,7 +418,6 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) {
|
||||
auto task_reconstruction_data = std::make_shared<TaskReconstructionData>();
|
||||
task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary());
|
||||
task_reconstruction_data->set_num_reconstructions(0);
|
||||
|
||||
RAY_CHECK_OK(
|
||||
mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr,
|
||||
/*failure_callback=*/
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
// clang-format off
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/ray_config.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
#include "ray/raylet/reconstruction_policy.h"
|
||||
|
||||
+28
-36
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <boost/bind.hpp>
|
||||
|
||||
#include "ray/raylet/format/node_manager_generated.h"
|
||||
#include "ray/raylet/raylet.h"
|
||||
|
||||
namespace ray {
|
||||
@@ -10,27 +11,25 @@ namespace raylet {
|
||||
|
||||
/// A constructor responsible for initializing the state of a worker.
|
||||
Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
|
||||
rpc::ClientCallManager &client_call_manager, bool is_worker)
|
||||
std::shared_ptr<LocalClientConnection> connection,
|
||||
rpc::ClientCallManager &client_call_manager)
|
||||
: worker_id_(worker_id),
|
||||
pid_(pid),
|
||||
port_(port),
|
||||
language_(language),
|
||||
port_(port),
|
||||
connection_(connection),
|
||||
dead_(false),
|
||||
blocked_(false),
|
||||
num_missed_heartbeats_(0),
|
||||
is_being_killed_(false),
|
||||
client_call_manager_(client_call_manager),
|
||||
is_worker_(is_worker) {
|
||||
client_call_manager_(client_call_manager) {
|
||||
if (port_ > 0) {
|
||||
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
|
||||
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
|
||||
}
|
||||
}
|
||||
|
||||
void Worker::MarkAsBeingKilled() { is_being_killed_ = true; }
|
||||
void Worker::MarkDead() { dead_ = true; }
|
||||
|
||||
bool Worker::IsBeingKilled() const { return is_being_killed_; }
|
||||
|
||||
bool Worker::IsWorker() const { return is_worker_; }
|
||||
bool Worker::IsDead() const { return dead_; }
|
||||
|
||||
void Worker::MarkBlocked() { blocked_ = true; }
|
||||
|
||||
@@ -44,8 +43,6 @@ pid_t Worker::Pid() const { return pid_; }
|
||||
|
||||
Language Worker::GetLanguage() const { return language_; }
|
||||
|
||||
const WorkerID &Worker::GetWorkerId() const { return worker_id_; }
|
||||
|
||||
int Worker::Port() const { return port_; }
|
||||
|
||||
void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; }
|
||||
@@ -79,6 +76,10 @@ void Worker::AssignActorId(const ActorID &actor_id) {
|
||||
|
||||
const ActorID &Worker::GetActorId() const { return actor_id_; }
|
||||
|
||||
const std::shared_ptr<LocalClientConnection> Worker::Connection() const {
|
||||
return connection_;
|
||||
}
|
||||
|
||||
const ResourceIdSet &Worker::GetLifetimeResourceIds() const {
|
||||
return lifetime_resource_ids_;
|
||||
}
|
||||
@@ -112,16 +113,10 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) {
|
||||
task_resource_ids_.Release(cpu_resources);
|
||||
}
|
||||
|
||||
void Worker::SetGetTaskReplyAndCallback(
|
||||
rpc::GetTaskReply *reply, const rpc::SendReplyCallback &&send_reply_callback) {
|
||||
RAY_CHECK(reply_ == nullptr && send_reply_callback_ == nullptr);
|
||||
reply_ = reply;
|
||||
send_reply_callback_ = std::move(send_reply_callback);
|
||||
}
|
||||
|
||||
bool Worker::UsePush() const { return rpc_client_ != nullptr; }
|
||||
|
||||
void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) {
|
||||
void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
|
||||
const std::function<void(Status)> finish_assign_callback) {
|
||||
const TaskSpecification &spec = task.GetTaskSpecification();
|
||||
if (rpc_client_ != nullptr) {
|
||||
// Use push mode.
|
||||
@@ -131,16 +126,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set)
|
||||
task.GetTaskSpecification().GetMessage());
|
||||
request.mutable_task()->mutable_task_execution_spec()->CopyFrom(
|
||||
task.GetTaskExecutionSpec().GetMessage());
|
||||
for (const auto &e : resource_id_set.ToProtobuf()) {
|
||||
auto resource = request.add_resource_ids();
|
||||
*resource = e;
|
||||
}
|
||||
request.set_resource_ids(resource_id_set.Serialize());
|
||||
|
||||
auto status = rpc_client_->AssignTask(
|
||||
request, [](Status status, const rpc::AssignTaskReply &reply) {
|
||||
// Worker has finished this task. There's nothing to do here
|
||||
// and assigning new task will be done when raylet receives
|
||||
// `TaskDone` message.
|
||||
});
|
||||
finish_assign_callback(status);
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId()
|
||||
<< " to worker " << worker_id_;
|
||||
@@ -151,19 +145,17 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set)
|
||||
} else {
|
||||
// Use pull mode. This corresponds to existing python/java workers that haven't been
|
||||
// migrated to core worker architecture.
|
||||
RAY_CHECK(reply_ != nullptr && send_reply_callback_ != nullptr);
|
||||
reply_->set_task_spec(task.GetTaskSpecification().Serialize());
|
||||
for (const auto &e : resource_id_set.ToProtobuf()) {
|
||||
auto resource = reply_->add_fractional_resource_ids();
|
||||
*resource = e;
|
||||
}
|
||||
send_reply_callback_(Status::OK(), nullptr, nullptr);
|
||||
reply_ = nullptr;
|
||||
send_reply_callback_ = nullptr;
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb);
|
||||
|
||||
auto message =
|
||||
protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()),
|
||||
fbb.CreateVector(resource_id_set_flatbuf));
|
||||
fbb.Finish(message);
|
||||
Connection()->WriteMessageAsync(
|
||||
static_cast<int64_t>(protocol::MessageType::ExecuteTask), fbb.GetSize(),
|
||||
fbb.GetBufferPointer(), finish_assign_callback);
|
||||
}
|
||||
// The status will be cleared when the worker dies.
|
||||
AssignTaskId(spec.TaskId());
|
||||
AssignJobId(spec.JobId());
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
+15
-33
@@ -3,15 +3,12 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ray/common/client_connection.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/task/scheduling_resources.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/common/task/task_common.h"
|
||||
#include "ray/protobuf/common.pb.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
#include "src/ray/protobuf/gcs.pb.h"
|
||||
#include "src/ray/protobuf/raylet.pb.h"
|
||||
#include "src/ray/rpc/server_call.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -24,12 +21,12 @@ class Worker {
|
||||
public:
|
||||
/// A constructor that initializes a worker object.
|
||||
Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port,
|
||||
rpc::ClientCallManager &client_call_manager, bool is_worker = true);
|
||||
std::shared_ptr<LocalClientConnection> connection,
|
||||
rpc::ClientCallManager &client_call_manager);
|
||||
/// A destructor responsible for freeing all worker state.
|
||||
~Worker() {}
|
||||
void MarkAsBeingKilled();
|
||||
bool IsBeingKilled() const;
|
||||
bool IsWorker() const;
|
||||
void MarkDead();
|
||||
bool IsDead() const;
|
||||
void MarkBlocked();
|
||||
void MarkUnblocked();
|
||||
bool IsBlocked() const;
|
||||
@@ -38,7 +35,6 @@ class Worker {
|
||||
/// Return the worker's PID.
|
||||
pid_t Pid() const;
|
||||
Language GetLanguage() const;
|
||||
const WorkerID &GetWorkerId() const;
|
||||
int Port() const;
|
||||
void AssignTaskId(const TaskID &task_id);
|
||||
const TaskID &GetAssignedTaskId() const;
|
||||
@@ -49,6 +45,7 @@ class Worker {
|
||||
const JobID &GetAssignedJobId() const;
|
||||
void AssignActorId(const ActorID &actor_id);
|
||||
const ActorID &GetActorId() const;
|
||||
const std::shared_ptr<LocalClientConnection> Connection() const;
|
||||
|
||||
const ResourceIdSet &GetLifetimeResourceIds() const;
|
||||
void SetLifetimeResourceIds(ResourceIdSet &resource_ids);
|
||||
@@ -60,34 +57,30 @@ class Worker {
|
||||
ResourceIdSet ReleaseTaskCpuResources();
|
||||
void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources);
|
||||
|
||||
int TickHeartbeatTimer() { return ++num_missed_heartbeats_; }
|
||||
void ClearHeartbeat() { num_missed_heartbeats_ = 0; }
|
||||
|
||||
/// When receiving a `GetTask` request from worker, this function should be called to
|
||||
/// pass the reply and callback of the request to the worker. Later, when we actually
|
||||
/// assign a task to the worker, the reply will be filled and the callback will be
|
||||
/// called.
|
||||
void SetGetTaskReplyAndCallback(rpc::GetTaskReply *reply,
|
||||
const rpc::SendReplyCallback &&send_reply_callback);
|
||||
|
||||
bool UsePush() const;
|
||||
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set);
|
||||
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
|
||||
const std::function<void(Status)> finish_assign_callback);
|
||||
|
||||
private:
|
||||
/// The worker's ID.
|
||||
WorkerID worker_id_;
|
||||
/// The worker's PID.
|
||||
pid_t pid_;
|
||||
/// The worker port.
|
||||
int port_;
|
||||
/// The language type of this worker.
|
||||
Language language_;
|
||||
/// Port that this worker listens on.
|
||||
/// If port <= 0, this indicates that the worker will not listen to a port.
|
||||
int port_;
|
||||
/// Connection state of a worker.
|
||||
std::shared_ptr<LocalClientConnection> connection_;
|
||||
/// The worker's currently assigned task.
|
||||
TaskID assigned_task_id_;
|
||||
/// Job ID for the worker's current assigned task.
|
||||
JobID assigned_job_id_;
|
||||
/// The worker's actor ID. If this is nil, then the worker is not an actor.
|
||||
ActorID actor_id_;
|
||||
/// Whether the worker is dead.
|
||||
bool dead_;
|
||||
/// Whether the worker is blocked. Workers become blocked in a `ray.get`, if
|
||||
/// they require a data dependency while executing a task.
|
||||
bool blocked_;
|
||||
@@ -98,22 +91,11 @@ class Worker {
|
||||
// of a task.
|
||||
ResourceIdSet task_resource_ids_;
|
||||
std::unordered_set<TaskID> blocked_task_ids_;
|
||||
/// How many heartbeats have been missed for this worker.
|
||||
int num_missed_heartbeats_;
|
||||
/// Indicates we have sent kill signal to the worker if it's true. We cannot treat the
|
||||
/// worker process as really dead until we lost the heartbeats from the worker.
|
||||
bool is_being_killed_;
|
||||
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
|
||||
/// workers.
|
||||
rpc::ClientCallManager &client_call_manager_;
|
||||
/// Indicates whether this is a worker or a driver.
|
||||
bool is_worker_;
|
||||
/// The rpc client to send tasks to this worker.
|
||||
std::unique_ptr<rpc::WorkerTaskClient> rpc_client_;
|
||||
/// Reply of the `GetTask` request.
|
||||
rpc::GetTaskReply *reply_ = nullptr;
|
||||
/// Callback of the `GetTask` request.
|
||||
rpc::SendReplyCallback send_reply_callback_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -12,6 +12,29 @@
|
||||
#include "ray/util/logging.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// A helper function to get a worker from a list.
|
||||
std::shared_ptr<ray::raylet::Worker> GetWorker(
|
||||
const std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
|
||||
const std::shared_ptr<ray::LocalClientConnection> &connection) {
|
||||
for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) {
|
||||
if ((*it)->Connection() == connection) {
|
||||
return (*it);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// A helper function to remove a worker from a list. Returns true if the worker
|
||||
// was found and removed.
|
||||
bool RemoveWorker(std::unordered_set<std::shared_ptr<ray::raylet::Worker>> &worker_pool,
|
||||
const std::shared_ptr<ray::raylet::Worker> &worker) {
|
||||
return worker_pool.erase(worker) > 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace ray {
|
||||
|
||||
namespace raylet {
|
||||
@@ -50,8 +73,8 @@ WorkerPool::~WorkerPool() {
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
// Kill all registered workers. NOTE(swang): This assumes that the registered
|
||||
// workers were started by the pool.
|
||||
for (const auto &worker_pair : entry.second.registered_workers) {
|
||||
pids_to_kill.insert(worker_pair.second->Pid());
|
||||
for (const auto &worker : entry.second.registered_workers) {
|
||||
pids_to_kill.insert(worker->Pid());
|
||||
}
|
||||
// Kill all the workers that have been started but not registered.
|
||||
for (const auto &starting_worker : entry.second.starting_worker_processes) {
|
||||
@@ -166,8 +189,7 @@ pid_t WorkerPool::StartProcess(const std::vector<std::string> &worker_command_ar
|
||||
return 0;
|
||||
}
|
||||
|
||||
Status WorkerPool::RegisterWorker(const WorkerID &worker_id,
|
||||
const std::shared_ptr<Worker> &worker) {
|
||||
Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker) {
|
||||
const auto pid = worker->Pid();
|
||||
const auto port = worker->Port();
|
||||
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port;
|
||||
@@ -183,35 +205,34 @@ Status WorkerPool::RegisterWorker(const WorkerID &worker_id,
|
||||
state.starting_worker_processes.erase(it);
|
||||
}
|
||||
|
||||
state.registered_workers.emplace(worker_id, std::move(worker));
|
||||
state.registered_workers.emplace(std::move(worker));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WorkerPool::RegisterDriver(const WorkerID &driver_id,
|
||||
const std::shared_ptr<Worker> &driver) {
|
||||
Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver) {
|
||||
RAY_CHECK(!driver->GetAssignedTaskId().IsNil());
|
||||
auto &state = GetStateForLanguage(driver->GetLanguage());
|
||||
state.registered_drivers.emplace(driver_id, std::move(driver));
|
||||
state.registered_drivers.insert(std::move(driver));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(const WorkerID &worker_id) const {
|
||||
std::shared_ptr<Worker> WorkerPool::GetRegisteredWorker(
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const {
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
auto ®istered_workers = entry.second.registered_workers;
|
||||
auto it = registered_workers.find(worker_id);
|
||||
if (it != registered_workers.end()) {
|
||||
return it->second;
|
||||
auto worker = GetWorker(entry.second.registered_workers, connection);
|
||||
if (worker != nullptr) {
|
||||
return worker;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(const WorkerID &worker_id) const {
|
||||
std::shared_ptr<Worker> WorkerPool::GetRegisteredDriver(
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const {
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
auto ®istered_drivers = entry.second.registered_drivers;
|
||||
auto it = registered_drivers.find(worker_id);
|
||||
if (it != registered_drivers.end()) {
|
||||
return it->second;
|
||||
auto driver = GetWorker(entry.second.registered_drivers, connection);
|
||||
if (driver != nullptr) {
|
||||
return driver;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
@@ -295,20 +316,18 @@ std::shared_ptr<Worker> WorkerPool::PopWorker(const TaskSpecification &task_spec
|
||||
|
||||
bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
|
||||
auto &state = GetStateForLanguage(worker->GetLanguage());
|
||||
RAY_CHECK(state.registered_workers.erase(worker->GetWorkerId()));
|
||||
RAY_CHECK(RemoveWorker(state.registered_workers, worker));
|
||||
|
||||
stats::CurrentWorker().Record(
|
||||
0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
|
||||
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
|
||||
|
||||
// Indicates that we disconnect a idle worker successfully.
|
||||
return (state.idle.erase(worker) > 0);
|
||||
return RemoveWorker(state.idle, worker);
|
||||
}
|
||||
|
||||
void WorkerPool::DisconnectDriver(const std::shared_ptr<Worker> &driver) {
|
||||
auto &state = GetStateForLanguage(driver->GetLanguage());
|
||||
RAY_CHECK(state.registered_drivers.erase(driver->GetWorkerId()));
|
||||
|
||||
RAY_CHECK(RemoveWorker(state.registered_drivers, driver));
|
||||
stats::CurrentDriver().Record(
|
||||
0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
|
||||
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
|
||||
@@ -325,8 +344,7 @@ std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
|
||||
std::vector<std::shared_ptr<Worker>> workers;
|
||||
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
for (const auto &worker_pair : entry.second.registered_workers) {
|
||||
auto &worker = worker_pair.second;
|
||||
for (const auto &worker : entry.second.registered_workers) {
|
||||
if (worker->GetAssignedJobId() == job_id) {
|
||||
workers.push_back(worker);
|
||||
}
|
||||
@@ -381,16 +399,14 @@ std::string WorkerPool::DebugString() const {
|
||||
void WorkerPool::RecordMetrics() const {
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
// Record worker.
|
||||
for (auto worker_pair : entry.second.registered_workers) {
|
||||
auto &worker = worker_pair.second;
|
||||
for (auto worker : entry.second.registered_workers) {
|
||||
stats::CurrentWorker().Record(
|
||||
worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
|
||||
{stats::WorkerPidKey, std::to_string(worker->Pid())}});
|
||||
}
|
||||
|
||||
// Record driver.
|
||||
for (auto driver_pair : entry.second.registered_drivers) {
|
||||
auto &driver = driver_pair.second;
|
||||
for (auto driver : entry.second.registered_drivers) {
|
||||
stats::CurrentDriver().Record(
|
||||
driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
|
||||
{stats::WorkerPidKey, std::to_string(driver->Pid())}});
|
||||
@@ -398,38 +414,6 @@ void WorkerPool::RecordMetrics() const {
|
||||
}
|
||||
}
|
||||
|
||||
void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats,
|
||||
std::vector<std::shared_ptr<Worker>> *dead_workers) {
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
// Worker heartbeat.
|
||||
for (const auto &worker_pair : entry.second.registered_workers) {
|
||||
auto &worker = worker_pair.second;
|
||||
if (worker->TickHeartbeatTimer() >= max_missed_heartbeats) {
|
||||
dead_workers->emplace_back(worker);
|
||||
}
|
||||
}
|
||||
|
||||
// Driver heartbeat.
|
||||
for (const auto &driver_pair : entry.second.registered_drivers) {
|
||||
auto &driver = driver_pair.second;
|
||||
if (driver->TickHeartbeatTimer() >= max_missed_heartbeats) {
|
||||
dead_workers->emplace_back(driver);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Worker> WorkerPool::GetWorker(const WorkerID &worker_id) {
|
||||
auto worker = GetRegisteredWorker(worker_id);
|
||||
if (!worker) {
|
||||
worker = GetRegisteredDriver(worker_id);
|
||||
if (!worker) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
return worker;
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -52,28 +52,29 @@ class WorkerPool {
|
||||
///
|
||||
/// \param The Worker to be registered.
|
||||
/// \return If the registration is successful.
|
||||
Status RegisterWorker(const WorkerID &worker_id, const std::shared_ptr<Worker> &worker);
|
||||
Status RegisterWorker(const std::shared_ptr<Worker> &worker);
|
||||
|
||||
/// Register a new driver.
|
||||
/// Driver is a treated as a special worker, so use WorkerID as key here.
|
||||
///
|
||||
/// \param The driver to be registered.
|
||||
/// \return If the registration is successful.
|
||||
Status RegisterDriver(const WorkerID &driver_id, const std::shared_ptr<Worker> &worker);
|
||||
Status RegisterDriver(const std::shared_ptr<Worker> &worker);
|
||||
|
||||
/// Get the client connection's registered worker.
|
||||
///
|
||||
/// \param The client connection owned by a registered worker.
|
||||
/// \return The Worker that owns the given client connection. Returns nullptr
|
||||
/// if the client has not registered a worker yet.
|
||||
std::shared_ptr<Worker> GetRegisteredWorker(const WorkerID &worker_id) const;
|
||||
std::shared_ptr<Worker> GetRegisteredWorker(
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const;
|
||||
|
||||
/// Get the client connection's registered driver.
|
||||
///
|
||||
/// \param The client connection owned by a registered driver.
|
||||
/// \return The Worker that owns the given client connection. Returns nullptr
|
||||
/// if the client has not registered a driver.
|
||||
std::shared_ptr<Worker> GetRegisteredDriver(const WorkerID &driver_id) const;
|
||||
std::shared_ptr<Worker> GetRegisteredDriver(
|
||||
const std::shared_ptr<LocalClientConnection> &connection) const;
|
||||
|
||||
/// Disconnect a registered worker.
|
||||
///
|
||||
@@ -130,21 +131,6 @@ class WorkerPool {
|
||||
/// Record metrics.
|
||||
void RecordMetrics() const;
|
||||
|
||||
/// Tick the heartbeat timer and get the workers that have timed out.
|
||||
/// A worker which has missed `max_missed_heartbeats` times would be treated as a
|
||||
/// dead process or the network to it has been down.
|
||||
///
|
||||
/// \param[in] max_missed_heartbeats The maximum number of heartbeats that can be
|
||||
/// missed before a worker times out.
|
||||
/// \param[out] dead_workers Workers whose processes have been dead.
|
||||
void TickHeartbeatTimer(int max_missed_heartbeats,
|
||||
std::vector<std::shared_ptr<Worker>> *dead_workers);
|
||||
|
||||
/// Return the pointer to the worker according to the worker id.
|
||||
///
|
||||
/// \param worker_id The worker id.
|
||||
std::shared_ptr<Worker> GetWorker(const WorkerID &worker_id);
|
||||
|
||||
protected:
|
||||
/// Asynchronously start a new worker process. Once the worker process has
|
||||
/// registered with an external server, the process should create and
|
||||
@@ -182,9 +168,9 @@ class WorkerPool {
|
||||
std::unordered_map<ActorID, std::shared_ptr<Worker>> idle_actor;
|
||||
/// All workers that have registered and are still connected, including both
|
||||
/// idle and executing.
|
||||
std::unordered_map<WorkerID, std::shared_ptr<Worker>> registered_workers;
|
||||
std::unordered_set<std::shared_ptr<Worker>> registered_workers;
|
||||
/// All drivers that have registered and are still connected.
|
||||
std::unordered_map<WorkerID, std::shared_ptr<Worker>> registered_drivers;
|
||||
std::unordered_set<std::shared_ptr<Worker>> registered_drivers;
|
||||
/// A map from the pids of starting worker processes
|
||||
/// to the number of their unregistered workers.
|
||||
std::unordered_map<pid_t, int> starting_worker_processes;
|
||||
|
||||
@@ -71,9 +71,19 @@ class WorkerPoolTest : public ::testing::Test {
|
||||
|
||||
std::shared_ptr<Worker> CreateWorker(pid_t pid,
|
||||
const Language &language = Language::PYTHON) {
|
||||
WorkerID worker_id = WorkerID::FromRandom();
|
||||
return std::shared_ptr<Worker>(new Worker(
|
||||
worker_id, pid, language, /* listening port */ -1, client_call_manager_));
|
||||
std::function<void(LocalClientConnection &)> client_handler =
|
||||
[this](LocalClientConnection &client) { HandleNewClient(client); };
|
||||
std::function<void(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *)>
|
||||
message_handler = [this](std::shared_ptr<LocalClientConnection> client,
|
||||
int64_t message_type, const uint8_t *message) {
|
||||
HandleMessage(client, message_type, message);
|
||||
};
|
||||
boost::asio::local::stream_protocol::socket socket(io_service_);
|
||||
auto client =
|
||||
LocalClientConnection::Create(client_handler, message_handler, std::move(socket),
|
||||
"worker", {}, error_message_type_);
|
||||
return std::shared_ptr<Worker>(new Worker(WorkerID::FromRandom(), pid, language, -1,
|
||||
client, client_call_manager_));
|
||||
}
|
||||
|
||||
void SetWorkerCommands(const WorkerCommandMap &worker_commands) {
|
||||
@@ -86,6 +96,10 @@ class WorkerPoolTest : public ::testing::Test {
|
||||
boost::asio::io_service io_service_;
|
||||
int64_t error_message_type_;
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
|
||||
private:
|
||||
void HandleNewClient(LocalClientConnection &){};
|
||||
void HandleMessage(std::shared_ptr<LocalClientConnection>, int64_t, const uint8_t *){};
|
||||
};
|
||||
|
||||
static inline TaskSpecification ExampleTaskSpec(
|
||||
@@ -117,23 +131,21 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
|
||||
workers.push_back(CreateWorker(pid));
|
||||
}
|
||||
for (const auto &worker : workers) {
|
||||
WorkerID worker_id = worker->GetWorkerId();
|
||||
// Check that there's still a starting worker process
|
||||
// before all workers have been registered
|
||||
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1);
|
||||
// Check that we cannot lookup the worker before it's registered.
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr);
|
||||
RAY_CHECK_OK(worker_pool_.RegisterWorker(worker_id, worker));
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
|
||||
RAY_CHECK_OK(worker_pool_.RegisterWorker(worker));
|
||||
// Check that we can lookup the worker after it's registered.
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), worker);
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker);
|
||||
}
|
||||
// Check that there's no starting worker process
|
||||
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0);
|
||||
for (const auto &worker : workers) {
|
||||
WorkerID worker_id = worker->GetWorkerId();
|
||||
worker_pool_.DisconnectWorker(worker);
|
||||
// Check that we cannot lookup the worker after it's disconnected.
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr);
|
||||
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,20 +6,14 @@ namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
void GrpcServer::Run() {
|
||||
std::string server_address;
|
||||
// Set unix domain socket or tcp address.
|
||||
if (!unix_socket_path_.empty()) {
|
||||
server_address = "unix://" + unix_socket_path_;
|
||||
} else {
|
||||
server_address = "0.0.0.0:" + std::to_string(port_);
|
||||
}
|
||||
std::string server_address("0.0.0.0:" + std::to_string(port_));
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
// TODO(hchen): Add options for authentication.
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||
// Register all the services to this server.
|
||||
if (services_.empty()) {
|
||||
RAY_LOG(WARNING) << "No service found when start grpc server " << name_;
|
||||
RAY_LOG(WARNING) << "No service is found when start grpc server " << name_;
|
||||
}
|
||||
for (auto &entry : services_) {
|
||||
builder.RegisterService(&entry.get());
|
||||
@@ -29,11 +23,7 @@ void GrpcServer::Run() {
|
||||
cq_ = builder.AddCompletionQueue();
|
||||
// Build and start server.
|
||||
server_ = builder.BuildAndStart();
|
||||
if (unix_socket_path_.empty()) {
|
||||
// For a TCP-based server, the actual port is decided after `AddListeningPort`.
|
||||
server_address = "0.0.0.0:" + std::to_string(port_);
|
||||
}
|
||||
RAY_LOG(INFO) << name_ << " server started, listening on " << server_address;
|
||||
RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << ".";
|
||||
|
||||
// Create calls for all the server call factories.
|
||||
for (auto &entry : server_call_factories_and_concurrencies_) {
|
||||
|
||||
@@ -32,16 +32,7 @@ class GrpcServer {
|
||||
/// \param[in] port The port to bind this server to. If it's 0, a random available port
|
||||
/// will be chosen.
|
||||
GrpcServer(std::string name, const uint32_t port)
|
||||
: name_(std::move(name)), port_(port), is_closed_(false) {}
|
||||
|
||||
/// Construct a gRPC server that listens on unix domain socket.
|
||||
///
|
||||
/// \param[in] name Name of this server, used for logging and debugging purpose.
|
||||
/// \param[in] unix_socket_path Unix domain socket full path.
|
||||
GrpcServer(std::string name, const std::string &unix_socket_path)
|
||||
: GrpcServer(std::move(name), 0) {
|
||||
unix_socket_path_ = unix_socket_path;
|
||||
}
|
||||
: name_(std::move(name)), port_(port), is_closed_(true) {}
|
||||
|
||||
/// Destruct this gRPC server.
|
||||
~GrpcServer() { Shutdown(); }
|
||||
@@ -82,8 +73,6 @@ class GrpcServer {
|
||||
int port_;
|
||||
/// Indicates whether this server has been closed.
|
||||
bool is_closed_;
|
||||
/// Unix domain socket path.
|
||||
std::string unix_socket_path_;
|
||||
/// The `grpc::Service` objects which should be registered to `ServerBuilder`.
|
||||
std::vector<std::reference_wrapper<grpc::Service>> services_;
|
||||
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/grpc_util.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/util/logging.h"
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/grpc_util.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/util/logging.h"
|
||||
#include "src/ray/protobuf/object_manager.grpc.pb.h"
|
||||
|
||||
@@ -1,440 +0,0 @@
|
||||
|
||||
#include "src/ray/rpc/raylet/raylet_client.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
#define RETURN_IF_DISCONNECTED(connected) \
|
||||
do { \
|
||||
if (!(connected)) { \
|
||||
return Status::Invalid("Raylet connection is closed."); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id,
|
||||
bool is_worker, const JobID &job_id, const Language &language,
|
||||
int port)
|
||||
: worker_id_(worker_id),
|
||||
is_worker_(is_worker),
|
||||
job_id_(job_id),
|
||||
language_(language),
|
||||
port_(port),
|
||||
main_service_(),
|
||||
work_(main_service_),
|
||||
client_call_manager_(main_service_),
|
||||
heartbeat_timer_(main_service_),
|
||||
is_connected_(false) {
|
||||
std::shared_ptr<grpc::Channel> channel =
|
||||
grpc::CreateChannel("unix://" + raylet_socket, grpc::InsecureChannelCredentials());
|
||||
stub_ = RayletService::NewStub(channel);
|
||||
|
||||
rpc_thread_ = std::thread([this]() { main_service_.run(); });
|
||||
RAY_LOG(DEBUG) << "Connecting to unix socket: "
|
||||
<< "unix://" + raylet_socket
|
||||
<< ", is worker: " << (is_worker_ ? "true" : "false")
|
||||
<< ", worker id: " << worker_id;
|
||||
// Try to register client `num_raylet_client_retry_times` times.
|
||||
TryRegisterClient(RayConfig::instance().num_raylet_client_retry_times());
|
||||
}
|
||||
|
||||
void RayletClient::TryRegisterClient(int retry_times) {
|
||||
// We should block here until register succeeds.
|
||||
for (int i = 0; i < retry_times; i++) {
|
||||
auto st = RegisterClient();
|
||||
if (st.ok()) {
|
||||
is_connected_ = true;
|
||||
Heartbeat();
|
||||
return;
|
||||
}
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(500));
|
||||
}
|
||||
RAY_LOG(FATAL) << "Worker " << worker_id_
|
||||
<< " failed to register to raylet server, worker id: " << worker_id_
|
||||
<< ", pid: " << static_cast<int>(getpid())
|
||||
<< ", is worker: " << is_worker_;
|
||||
}
|
||||
|
||||
RayletClient::~RayletClient() {
|
||||
is_connected_ = false;
|
||||
main_service_.stop();
|
||||
rpc_thread_.join();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::Disconnect() {
|
||||
DisconnectClientRequest disconnect_client_request;
|
||||
disconnect_client_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
DisconnectClientReply reply;
|
||||
grpc::ClientContext context;
|
||||
auto status = stub_->DisconnectClient(&context, disconnect_client_request, &reply);
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Worker " << worker_id_
|
||||
<< " failed to disconnect from raylet, msg: "
|
||||
<< status.error_message();
|
||||
}
|
||||
return GrpcStatusToRayStatus(status);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
SubmitTaskRequest submit_task_request;
|
||||
submit_task_request.set_worker_id(worker_id_.Binary());
|
||||
submit_task_request.mutable_task_spec()->CopyFrom(task_spec.GetMessage());
|
||||
|
||||
auto callback = [this](const Status &status, const SubmitTaskReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send SubmitTaskRequest, msg: " << status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_.CreateCall<RayletService, SubmitTaskRequest, SubmitTaskReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncSubmitTask, submit_task_request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::GetTask(std::unique_ptr<ray::TaskSpecification> *task_spec) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
GetTaskRequest get_task_request;
|
||||
get_task_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
grpc::ClientContext context;
|
||||
GetTaskReply reply;
|
||||
// The actual RPC.
|
||||
auto status = stub_->GetTask(&context, get_task_request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
resource_ids_.clear();
|
||||
// Parse resources that would be used by this assigned task.
|
||||
for (int64_t i = 0; i < reply.fractional_resource_ids().size(); ++i) {
|
||||
auto const &fractional_resource_ids = reply.fractional_resource_ids()[i];
|
||||
auto &acquired_resources = resource_ids_[fractional_resource_ids.resource_name()];
|
||||
|
||||
// Each resource includes a series of resource IDs (e.g., GPU 0) and corresponding
|
||||
// amount for that resource ID. If the resource amount is fractional, then there
|
||||
// should only be one resource ID.
|
||||
size_t num_resource_ids = fractional_resource_ids.resource_ids().size();
|
||||
size_t num_resource_fractions = fractional_resource_ids.resource_fractions().size();
|
||||
RAY_CHECK(num_resource_ids == num_resource_fractions);
|
||||
RAY_CHECK(num_resource_ids > 0);
|
||||
for (size_t j = 0; j < num_resource_ids; ++j) {
|
||||
int64_t resource_id = fractional_resource_ids.resource_ids()[j];
|
||||
double resource_fraction = fractional_resource_ids.resource_fractions()[j];
|
||||
if (num_resource_ids > 1) {
|
||||
int64_t whole_fraction = resource_fraction;
|
||||
RAY_CHECK(whole_fraction == resource_fraction);
|
||||
}
|
||||
acquired_resources.emplace_back(resource_id, resource_fraction);
|
||||
}
|
||||
}
|
||||
task_spec->reset(new ray::TaskSpecification(reply.task_spec()));
|
||||
} else {
|
||||
*task_spec = nullptr;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to get task, msg: " << status.error_message();
|
||||
}
|
||||
return GrpcStatusToRayStatus(status);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::TaskDone() {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
TaskDoneRequest task_done_request;
|
||||
task_done_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const TaskDoneReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send TaskDoneRequest, msg: " << status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_.CreateCall<RayletService, TaskDoneRequest, TaskDoneReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncTaskDone, task_done_request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::FetchOrReconstruct(const std::vector<ObjectID> &object_ids,
|
||||
bool fetch_only,
|
||||
const TaskID ¤t_task_id) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
FetchOrReconstructRequest fetch_or_reconstruct_request;
|
||||
fetch_or_reconstruct_request.set_fetch_only(fetch_only);
|
||||
fetch_or_reconstruct_request.set_task_id(current_task_id.Binary());
|
||||
fetch_or_reconstruct_request.set_worker_id(worker_id_.Binary());
|
||||
IdVectorToProtobuf<ObjectID, FetchOrReconstructRequest>(
|
||||
object_ids, fetch_or_reconstruct_request,
|
||||
&FetchOrReconstructRequest::add_object_ids);
|
||||
|
||||
// Callback to deal with reply.
|
||||
auto callback = [this](const Status &status, const FetchOrReconstructReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send FetchOrReconstructRequest, msg: "
|
||||
<< status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<RayletService, FetchOrReconstructRequest, FetchOrReconstructReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncFetchOrReconstruct,
|
||||
fetch_or_reconstruct_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
NotifyUnblockedRequest notify_unblocked_request;
|
||||
notify_unblocked_request.set_worker_id(worker_id_.Binary());
|
||||
notify_unblocked_request.set_task_id(current_task_id.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const NotifyUnblockedReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send NotifyUnblockedRequest, msg: "
|
||||
<< status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<RayletService, NotifyUnblockedRequest, NotifyUnblockedReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncNotifyUnblocked,
|
||||
notify_unblocked_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_returns,
|
||||
int64_t timeout_milliseconds, bool wait_local,
|
||||
const TaskID ¤t_task_id, WaitResultPair *result) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
WaitRequest wait_request;
|
||||
wait_request.set_worker_id(worker_id_.Binary());
|
||||
wait_request.set_timeout(timeout_milliseconds);
|
||||
wait_request.set_wait_local(wait_local);
|
||||
wait_request.set_task_id(current_task_id.Binary());
|
||||
wait_request.set_num_ready_objects(num_returns);
|
||||
IdVectorToProtobuf<ObjectID, WaitRequest>(object_ids, wait_request,
|
||||
&WaitRequest::add_object_ids);
|
||||
|
||||
grpc::ClientContext context;
|
||||
WaitReply reply;
|
||||
auto status = stub_->Wait(&context, wait_request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
result->first = IdVectorFromProtobuf<ObjectID>(reply.found());
|
||||
result->second = IdVectorFromProtobuf<ObjectID>(reply.remaining());
|
||||
} else {
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send WaitRequest, msg: " << status.error_message();
|
||||
}
|
||||
|
||||
return GrpcStatusToRayStatus(status);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
PushErrorRequest push_error_request;
|
||||
push_error_request.set_job_id(job_id.Binary());
|
||||
push_error_request.set_type(type);
|
||||
push_error_request.set_error_message(error_message);
|
||||
push_error_request.set_timestamp(timestamp);
|
||||
push_error_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const PushErrorReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send PushErrorRequest, msg: " << status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_.CreateCall<RayletService, PushErrorRequest, PushErrorReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncPushError, push_error_request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
PushProfileEventsRequest push_profile_events_request;
|
||||
push_profile_events_request.mutable_profile_table_data()->CopyFrom(profile_events);
|
||||
push_profile_events_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const PushProfileEventsReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send PushProfileEventsRequest, msg: "
|
||||
<< status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<RayletService, PushProfileEventsRequest, PushProfileEventsReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncPushProfileEvents,
|
||||
push_profile_events_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::FreeObjects(const std::vector<ray::ObjectID> &object_ids,
|
||||
bool local_only, bool delete_creating_tasks) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
FreeObjectsInStoreRequest free_objects_request;
|
||||
free_objects_request.set_local_only(local_only);
|
||||
free_objects_request.set_delete_creating_tasks(delete_creating_tasks);
|
||||
free_objects_request.set_worker_id(worker_id_.Binary());
|
||||
IdVectorToProtobuf<ray::ObjectID, FreeObjectsInStoreRequest>(
|
||||
object_ids, free_objects_request, &FreeObjectsInStoreRequest::add_object_ids);
|
||||
|
||||
auto callback = [this](const Status &status, const FreeObjectsInStoreReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send FreeObjectsInStoreRequest, msg: "
|
||||
<< status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<RayletService, FreeObjectsInStoreRequest, FreeObjectsInStoreReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncFreeObjectsInStore,
|
||||
free_objects_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id,
|
||||
ActorCheckpointID &checkpoint_id) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
PrepareActorCheckpointRequest prepare_actor_checkpoint_request;
|
||||
prepare_actor_checkpoint_request.set_actor_id(actor_id.Binary());
|
||||
prepare_actor_checkpoint_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
grpc::ClientContext context;
|
||||
PrepareActorCheckpointReply reply;
|
||||
auto status =
|
||||
stub_->PrepareActorCheckpoint(&context, prepare_actor_checkpoint_request, &reply);
|
||||
|
||||
if (status.ok()) {
|
||||
checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id());
|
||||
} else {
|
||||
RAY_LOG(INFO) << "Worker " << worker_id_
|
||||
<< " failed to send PrepareActorCheckpointRequest, msg: "
|
||||
<< status.error_message();
|
||||
}
|
||||
|
||||
return GrpcStatusToRayStatus(status);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::NotifyActorResumedFromCheckpoint(
|
||||
const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
NotifyActorResumedFromCheckpointRequest notify_actor_resumed_from_checkpoint_request;
|
||||
notify_actor_resumed_from_checkpoint_request.set_actor_id(actor_id.Binary());
|
||||
notify_actor_resumed_from_checkpoint_request.set_checkpoint_id(checkpoint_id.Binary());
|
||||
notify_actor_resumed_from_checkpoint_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status,
|
||||
const NotifyActorResumedFromCheckpointReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "NotifyActorResumedFromCheckpoint failed, msg: "
|
||||
<< status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<RayletService, NotifyActorResumedFromCheckpointRequest,
|
||||
NotifyActorResumedFromCheckpointReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncNotifyActorResumedFromCheckpoint,
|
||||
notify_actor_resumed_from_checkpoint_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::SetResource(const std::string &resource_name,
|
||||
const double capacity,
|
||||
const ray::ClientID &client_id) {
|
||||
RETURN_IF_DISCONNECTED(is_connected_);
|
||||
SetResourceRequest set_resource_request;
|
||||
set_resource_request.set_resource_name(resource_name);
|
||||
set_resource_request.set_capacity(capacity);
|
||||
set_resource_request.set_client_id(client_id.Binary());
|
||||
set_resource_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const SetResourceReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "SetResource failed, msg: " << status.message();
|
||||
}
|
||||
};
|
||||
|
||||
auto call = client_call_manager_
|
||||
.CreateCall<RayletService, SetResourceRequest, SetResourceReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncSetResource,
|
||||
set_resource_request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::RegisterClient() {
|
||||
RegisterClientRequest register_client_request;
|
||||
register_client_request.set_is_worker(is_worker_);
|
||||
register_client_request.set_worker_id(worker_id_.Binary());
|
||||
register_client_request.set_worker_pid(getpid());
|
||||
register_client_request.set_job_id(job_id_.Binary());
|
||||
register_client_request.set_language(language_);
|
||||
register_client_request.set_port(port_);
|
||||
|
||||
grpc::ClientContext context;
|
||||
RegisterClientReply reply;
|
||||
auto status = stub_->RegisterClient(&context, register_client_request, &reply);
|
||||
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(DEBUG) << "Worker " << worker_id_
|
||||
<< " failed to register client, msg: " << status.error_message();
|
||||
}
|
||||
|
||||
return GrpcStatusToRayStatus(status);
|
||||
}
|
||||
|
||||
void RayletClient::Heartbeat() {
|
||||
if (!is_connected_) {
|
||||
return;
|
||||
}
|
||||
HeartbeatRequest heartbeat_request;
|
||||
heartbeat_request.set_is_worker(is_worker_);
|
||||
heartbeat_request.set_worker_id(worker_id_.Binary());
|
||||
|
||||
auto callback = [this](const Status &status, const HeartbeatReply &reply) {
|
||||
if (!status.ok() && is_connected_) {
|
||||
is_connected_ = false;
|
||||
RAY_LOG(INFO) << "Heartbeat failed, msg: " << status.message();
|
||||
}
|
||||
};
|
||||
auto call =
|
||||
client_call_manager_.CreateCall<RayletService, HeartbeatRequest, HeartbeatReply>(
|
||||
*stub_, &RayletService::Stub::PrepareAsyncHeartbeat, heartbeat_request,
|
||||
callback);
|
||||
|
||||
heartbeat_timer_.expires_from_now(boost::posix_time::milliseconds(
|
||||
RayConfig::instance().heartbeat_timeout_milliseconds()));
|
||||
heartbeat_timer_.async_wait([this](const boost::system::error_code &error) {
|
||||
RAY_CHECK(!error);
|
||||
Heartbeat();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
@@ -1,258 +0,0 @@
|
||||
#ifndef RAY_RPC_RAYLET_SERVER_H
|
||||
#define RAY_RPC_RAYLET_SERVER_H
|
||||
|
||||
#include "src/ray/rpc/grpc_server.h"
|
||||
#include "src/ray/rpc/server_call.h"
|
||||
|
||||
#include "src/ray/protobuf/raylet.grpc.pb.h"
|
||||
#include "src/ray/protobuf/raylet.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Implementations of the `RayletService`, check interface in
|
||||
/// `src/ray/protobuf/raylet.proto`.
|
||||
class RayletServiceHandler {
|
||||
public:
|
||||
virtual void HandleRegisterClientRequest(const RegisterClientRequest &request,
|
||||
RegisterClientReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `SubmitTask` request.
|
||||
/// The implementation can handle this request asynchronously. When handling is done,
|
||||
/// the `send_reply_callback` should be called.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] send_reply_callback The callback to be called when the request is done.
|
||||
virtual void HandleSubmitTaskRequest(const SubmitTaskRequest &request,
|
||||
SubmitTaskReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `DisconnectClient` request.
|
||||
virtual void HandleDisconnectClientRequest(const DisconnectClientRequest &request,
|
||||
DisconnectClientReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `GetTask` request.
|
||||
virtual void HandleGetTaskRequest(const GetTaskRequest &request, GetTaskReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `TaskDone` request.
|
||||
virtual void HandleTaskDoneRequest(const TaskDoneRequest &request, TaskDoneReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `HandleFetchOrReconstruct` request.
|
||||
virtual void HandleFetchOrReconstructRequest(const FetchOrReconstructRequest &request,
|
||||
FetchOrReconstructReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `HandleNotifyUnblocked` request.
|
||||
virtual void HandleNotifyUnblockedRequest(const NotifyUnblockedRequest &request,
|
||||
NotifyUnblockedReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `Wait` request.
|
||||
virtual void HandleWaitRequest(const WaitRequest &request, WaitReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `PushError` request.
|
||||
virtual void HandlePushErrorRequest(const PushErrorRequest &request,
|
||||
PushErrorReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `PushProfileEvents` request.
|
||||
virtual void HandlePushProfileEventsRequest(const PushProfileEventsRequest &request,
|
||||
PushProfileEventsReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `FreeObjectsInStoreInObjectStore` request.
|
||||
virtual void HandleFreeObjectsInStoreRequest(const FreeObjectsInStoreRequest &request,
|
||||
FreeObjectsInStoreReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `PrepareActorCheckpoint` request.
|
||||
virtual void HandlePrepareActorCheckpointRequest(
|
||||
const PrepareActorCheckpointRequest &request, PrepareActorCheckpointReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `NotifyActorResumedFromCheckpoint` request.
|
||||
virtual void HandleNotifyActorResumedFromCheckpointRequest(
|
||||
const NotifyActorResumedFromCheckpointRequest &request,
|
||||
NotifyActorResumedFromCheckpointReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `SetResource` request.
|
||||
virtual void HandleSetResourceRequest(const SetResourceRequest &request,
|
||||
SetResourceReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
/// Handle a `SetResourceReply` request.
|
||||
virtual void HandleHeartbeatRequest(const HeartbeatRequest &request,
|
||||
HeartbeatReply *reply,
|
||||
SendReplyCallback send_reply_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcService` for `RayletGrpcService`.
|
||||
class RayletGrpcService : public GrpcService {
|
||||
public:
|
||||
/// Construct a `RayletGrpcService`.
|
||||
///
|
||||
/// \param[in] io_service Service used to handle incoming requests
|
||||
/// \param[in] handler The service handler that actually handle the requests.
|
||||
RayletGrpcService(boost::asio::io_service &io_service,
|
||||
RayletServiceHandler &service_handler)
|
||||
: GrpcService(io_service), service_handler_(service_handler){};
|
||||
|
||||
protected:
|
||||
grpc::Service &GetGrpcService() override { return service_; }
|
||||
|
||||
void InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) override {
|
||||
// Initialize the factory for `RegisterClient` requests.
|
||||
std::unique_ptr<ServerCallFactory> register_client_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
RegisterClientRequest, RegisterClientReply>(
|
||||
service_, &RayletService::AsyncService::RequestRegisterClient,
|
||||
service_handler_, &RayletServiceHandler::HandleRegisterClientRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(register_client_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `SubmitTask` requests.
|
||||
std::unique_ptr<ServerCallFactory> submit_task_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, SubmitTaskRequest,
|
||||
SubmitTaskReply>(
|
||||
service_, &RayletService::AsyncService::RequestSubmitTask, service_handler_,
|
||||
&RayletServiceHandler::HandleSubmitTaskRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(submit_task_call_factory), 20);
|
||||
|
||||
// Initialize the factory for `DisconnectClient` requests.
|
||||
std::unique_ptr<ServerCallFactory> disconnect_client_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
DisconnectClientRequest, DisconnectClientReply>(
|
||||
service_, &RayletService::AsyncService::RequestDisconnectClient,
|
||||
service_handler_, &RayletServiceHandler::HandleDisconnectClientRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(disconnect_client_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `GetTask` requests.
|
||||
std::unique_ptr<ServerCallFactory> get_task_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, GetTaskRequest,
|
||||
GetTaskReply>(
|
||||
service_, &RayletService::AsyncService::RequestGetTask, service_handler_,
|
||||
&RayletServiceHandler::HandleGetTaskRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(get_task_call_factory), 20);
|
||||
|
||||
// Initialize the factory for `TaskDone` requests.
|
||||
std::unique_ptr<ServerCallFactory> task_done_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, TaskDoneRequest,
|
||||
TaskDoneReply>(
|
||||
service_, &RayletService::AsyncService::RequestTaskDone, service_handler_,
|
||||
&RayletServiceHandler::HandleTaskDoneRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(task_done_call_factory), 20);
|
||||
|
||||
// Initialize the factory for `FetchOrReconstruct` requests.
|
||||
std::unique_ptr<ServerCallFactory> fetch_or_reconstruct_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
FetchOrReconstructRequest, FetchOrReconstructReply>(
|
||||
service_, &RayletService::AsyncService::RequestFetchOrReconstruct,
|
||||
service_handler_, &RayletServiceHandler::HandleFetchOrReconstructRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(fetch_or_reconstruct_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `NotifyUnblocked` requests.
|
||||
std::unique_ptr<ServerCallFactory> notify_unblocked_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
NotifyUnblockedRequest, NotifyUnblockedReply>(
|
||||
service_, &RayletService::AsyncService::RequestNotifyUnblocked,
|
||||
service_handler_, &RayletServiceHandler::HandleNotifyUnblockedRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(notify_unblocked_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `Wait` requests.
|
||||
std::unique_ptr<ServerCallFactory> wait_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, WaitRequest,
|
||||
WaitReply>(
|
||||
service_, &RayletService::AsyncService::RequestWait, service_handler_,
|
||||
&RayletServiceHandler::HandleWaitRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(std::move(wait_call_factory),
|
||||
20);
|
||||
|
||||
// Initialize the factory for `PushError` requests.
|
||||
std::unique_ptr<ServerCallFactory> push_error_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, PushErrorRequest,
|
||||
PushErrorReply>(
|
||||
service_, &RayletService::AsyncService::RequestPushError, service_handler_,
|
||||
&RayletServiceHandler::HandlePushErrorRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(push_error_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `PushProfileEvents` requests.
|
||||
std::unique_ptr<ServerCallFactory> push_profile_events_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
PushProfileEventsRequest, PushProfileEventsReply>(
|
||||
service_, &RayletService::AsyncService::RequestPushProfileEvents,
|
||||
service_handler_, &RayletServiceHandler::HandlePushProfileEventsRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(push_profile_events_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `FreeObjectsInStore` requests.
|
||||
std::unique_ptr<ServerCallFactory> free_objects_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
FreeObjectsInStoreRequest, FreeObjectsInStoreReply>(
|
||||
service_, &RayletService::AsyncService::RequestFreeObjectsInStore,
|
||||
service_handler_, &RayletServiceHandler::HandleFreeObjectsInStoreRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(free_objects_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `PrepareActorCheckpoint` requests.
|
||||
std::unique_ptr<ServerCallFactory> prepare_actor_checkpoint_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
PrepareActorCheckpointRequest,
|
||||
PrepareActorCheckpointReply>(
|
||||
service_, &RayletService::AsyncService::RequestPrepareActorCheckpoint,
|
||||
service_handler_, &RayletServiceHandler::HandlePrepareActorCheckpointRequest,
|
||||
cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(prepare_actor_checkpoint_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `NotifyActorResumedFromCheckpoint` requests.
|
||||
std::unique_ptr<ServerCallFactory> notify_actor_resumed_from_checkpoint_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler,
|
||||
NotifyActorResumedFromCheckpointRequest,
|
||||
NotifyActorResumedFromCheckpointReply>(
|
||||
service_,
|
||||
&RayletService::AsyncService::RequestNotifyActorResumedFromCheckpoint,
|
||||
service_handler_,
|
||||
&RayletServiceHandler::HandleNotifyActorResumedFromCheckpointRequest, cq,
|
||||
main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(notify_actor_resumed_from_checkpoint_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `SetResource` requests.
|
||||
std::unique_ptr<ServerCallFactory> set_resource_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, SetResourceRequest,
|
||||
SetResourceReply>(
|
||||
service_, &RayletService::AsyncService::RequestSetResource, service_handler_,
|
||||
&RayletServiceHandler::HandleSetResourceRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(set_resource_call_factory), 10);
|
||||
|
||||
// Initialize the factory for `Heartbeat` requests.
|
||||
std::unique_ptr<ServerCallFactory> heartbeat_call_factory(
|
||||
new ServerCallFactoryImpl<RayletService, RayletServiceHandler, HeartbeatRequest,
|
||||
HeartbeatReply>(
|
||||
service_, &RayletService::AsyncService::RequestHeartbeat, service_handler_,
|
||||
&RayletServiceHandler::HandleHeartbeatRequest, cq, main_service_));
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(heartbeat_call_factory), 10);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The grpc async service object.
|
||||
RayletService::AsyncService service_;
|
||||
/// The service handler that actually handle the requests.
|
||||
RayletServiceHandler &service_handler_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user