From 599cc2be600761cbfdff3fc1b5da15a56b488c6d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sat, 17 Aug 2019 19:11:32 -0700 Subject: [PATCH] Revert raylet to worker GRPC communication back to asio (#5450) --- BUILD.bazel | 55 +- .../main/java/org/ray/api/test/BaseTest.java | 6 +- .../ray/api/test/ResourcesManagementTest.java | 1 - python/ray/_raylet.pyx | 2 +- python/ray/includes/common.pxd | 10 +- python/ray/includes/libraylet.pxd | 12 +- src/ray/common/grpc_util.h | 11 - src/ray/common/ray_config_def.h | 9 +- src/ray/common/task/scheduling_resources.cc | 35 +- src/ray/common/task/scheduling_resources.h | 16 +- src/ray/common/task/task.h | 2 +- src/ray/core_worker/common.h | 2 +- src/ray/core_worker/core_worker.h | 4 +- ...g_ray_runtime_raylet_NativeRayletClient.cc | 4 +- src/ray/core_worker/object_interface.h | 2 - .../store_provider/local_plasma_provider.h | 2 +- .../store_provider/plasma_store_provider.h | 4 +- src/ray/core_worker/task_interface.h | 2 - src/ray/core_worker/test/core_worker_test.cc | 2 +- .../core_worker/transport/raylet_transport.h | 4 +- src/ray/object_manager/object_manager.cc | 5 +- src/ray/protobuf/common.proto | 12 - src/ray/protobuf/raylet.proto | 215 ----- src/ray/protobuf/worker.proto | 4 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 292 +++++++ src/ray/raylet/lineage_cache_test.cc | 1 + src/ray/raylet/node_manager.cc | 733 +++++++++--------- src/ray/raylet/node_manager.h | 135 +--- src/ray/raylet/raylet.cc | 61 +- src/ray/raylet/raylet.h | 10 +- src/ray/raylet/raylet_client.cc | 392 ++++++++++ src/ray/{rpc => }/raylet/raylet_client.h | 127 ++- src/ray/raylet/reconstruction_policy_test.cc | 2 +- src/ray/raylet/task_dependency_manager.h | 1 - src/ray/raylet/worker.cc | 64 +- src/ray/raylet/worker.h | 48 +- src/ray/raylet/worker_pool.cc | 106 ++- src/ray/raylet/worker_pool.h | 30 +- src/ray/raylet/worker_pool_test.cc | 30 +- src/ray/rpc/grpc_server.cc | 16 +- src/ray/rpc/grpc_server.h | 13 +- 
.../rpc/node_manager/node_manager_client.h | 1 - .../object_manager/object_manager_client.h | 1 - src/ray/rpc/raylet/raylet_client.cc | 440 ----------- src/ray/rpc/raylet/raylet_server.h | 258 ------ 45 files changed, 1418 insertions(+), 1764 deletions(-) delete mode 100644 src/ray/protobuf/raylet.proto create mode 100644 src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc create mode 100644 src/ray/raylet/raylet_client.cc rename src/ray/{rpc => }/raylet/raylet_client.h (70%) delete mode 100644 src/ray/rpc/raylet/raylet_client.cc delete mode 100644 src/ray/rpc/raylet/raylet_server.h diff --git a/BUILD.bazel b/BUILD.bazel index 05da0f0a7..32e57bf85 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -65,20 +65,6 @@ cc_proto_library( deps = [":object_manager_proto"], ) -proto_library( - name = "raylet_proto", - srcs = ["src/ray/protobuf/raylet.proto"], - deps = [ - ":common_proto", - ":gcs_proto", - ], -) - -cc_proto_library( - name = "raylet_cc_proto", - deps = [":raylet_proto"], -) - proto_library( name = "worker_proto", srcs = ["src/ray/protobuf/worker.proto"], @@ -182,34 +168,6 @@ cc_library( ], ) -# Raylet gRPC lib. -cc_grpc_library( - name = "raylet_cc_grpc", - srcs = [":raylet_proto"], - grpc_only = True, - deps = [":raylet_cc_proto"], -) - -# Raylet rpc server and client. -cc_library( - name = "raylet_rpc", - srcs = glob([ - "src/ray/rpc/raylet/*.cc", - ]), - hdrs = glob([ - "src/ray/rpc/raylet/*.h", - "src/ray/raylet/*.h", - ]), - copts = COPTS, - deps = [ - ":grpc_common_lib", - ":ray_common", - ":raylet_cc_grpc", - "@boost//:asio", - "@com_github_grpc_grpc//:grpc++", - ], -) - # Worker gRPC lib. 
cc_grpc_library( name = "worker_cc_grpc", @@ -263,6 +221,7 @@ cc_library( copts = COPTS, deps = [ ":common_cc_proto", + ":node_manager_fbs", ":ray_util", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -361,11 +320,11 @@ cc_library( deps = [ ":common_cc_proto", ":gcs", + ":node_manager_fbs", ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", - ":raylet_rpc", ":stats_lib", ":worker_rpc", "@boost//:asio", @@ -461,6 +420,7 @@ cc_test( srcs = ["src/ray/raylet/lineage_cache_test.cc"], copts = COPTS, deps = [ + ":node_manager_fbs", ":raylet_lib", "@com_google_googletest//:gtest_main", ], @@ -471,6 +431,7 @@ cc_test( srcs = ["src/ray/raylet/reconstruction_policy_test.cc"], copts = COPTS, deps = [ + ":node_manager_fbs", ":object_manager", ":raylet_lib", "@com_google_googletest//:gtest_main", @@ -670,6 +631,7 @@ cc_library( deps = [ ":gcs_cc_proto", ":hiredis", + ":node_manager_fbs", ":node_manager_rpc", ":ray_common", ":ray_util", @@ -725,6 +687,13 @@ flatbuffer_cc_library( out_prefix = "src/ray/common/", ) +flatbuffer_cc_library( + name = "node_manager_fbs", + srcs = ["src/ray/raylet/format/node_manager.fbs"], + flatc_args = FLATC_ARGS, + out_prefix = "src/ray/raylet/format/", +) + flatbuffer_cc_library( name = "object_manager_fbs", srcs = ["src/ray/object_manager/format/object_manager.fbs"], diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java index 0170aa3a2..4c3973064 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java @@ -24,7 +24,11 @@ public class BaseTest { // These files need to be deleted after each test case. 
filesToDelete = ImmutableList.of( new File(Ray.getRuntimeContext().getRayletSocketName()), - new File(Ray.getRuntimeContext().getObjectStoreSocketName()) + new File(Ray.getRuntimeContext().getObjectStoreSocketName()), + // TODO(pcm): This is a workaround for the issue described + // in the PR description of https://github.com/ray-project/ray/pull/5450 + // and should be fixed properly. + new File("/tmp/ray/test/raylet_socket") ); // Make sure the files will be deleted even if the test doesn't exit gracefully. filesToDelete.forEach(File::deleteOnExit); diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java index dca559764..afae55be9 100644 --- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -102,4 +102,3 @@ public class ResourcesManagementTest extends BaseTest { } } - diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index d9edd13f2..a1e9387fe 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -374,7 +374,7 @@ cdef class RayletClient: @property def client_id(self): - return ClientID(self.client.get().GetWorkerId().Binary()) + return ClientID(self.client.get().GetWorkerID().Binary()) @property def job_id(self): diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 97bb318a0..18a248304 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -88,16 +88,16 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: cdef extern from "ray/protobuf/common.pb.h" nogil: - cdef cppclass CLanguage "ray::rpc::Language": + cdef cppclass CLanguage "Language": pass # This is a workaround for C++ enum class since Cython has no corresponding # representation. 
-cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc::Language" nogil: - cdef CLanguage LANGUAGE_PYTHON "ray::rpc::Language::PYTHON" - cdef CLanguage LANGUAGE_CPP "ray::rpc::Language::CPP" - cdef CLanguage LANGUAGE_JAVA "ray::rpc::Language::JAVA" +cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil: + cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON" + cdef CLanguage LANGUAGE_CPP "Language::CPP" + cdef CLanguage LANGUAGE_JAVA "Language::JAVA" cdef extern from "ray/common/task/scheduling_resources.h" \ diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index c0ff3e614..45746da2f 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -23,14 +23,14 @@ from ray.includes.task cimport CTaskSpec cdef extern from "ray/protobuf/gcs.pb.h" nogil: - cdef cppclass GCSProfileEvent "ray::rpc::ProfileTableData::ProfileEvent": + cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent": void set_event_type(const c_string &value) void set_start_time(double value) void set_end_time(double value) c_string set_extra_data(const c_string &value) GCSProfileEvent() - cdef cppclass GCSProfileTableData "ray::rpc::ProfileTableData": + cdef cppclass GCSProfileTableData "ProfileTableData": void set_component_type(const c_string &value) void set_component_id(const c_string &value) void set_node_ip_address(const c_string &value) @@ -43,12 +43,13 @@ ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair -cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil: - cdef cppclass CRayletClient "ray::rpc::RayletClient": +cdef extern from "ray/raylet/raylet_client.h" nogil: + cdef cppclass CRayletClient "RayletClient": CRayletClient(const c_string &raylet_socket, const CWorkerID &worker_id, c_bool is_worker, const CJobID &job_id, const CLanguage &language) + CRayStatus Disconnect() CRayStatus 
SubmitTask(const CTaskSpec &task_spec) CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec) CRayStatus TaskDone() @@ -72,8 +73,7 @@ cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil: const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const - CWorkerID GetWorkerId() const + CWorkerID GetWorkerID() const CJobID GetJobID() const c_bool IsWorker() const - CRayStatus Disconnect() const ResourceMappingType &GetResourceIDs() const diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index edcb58b5e..1931b3937 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -82,17 +82,6 @@ inline std::vector VectorFromProtobuf( return std::vector(pb_repeated.begin(), pb_repeated.end()); } -template -using AddFunction = void (Message::*)(const ::std::string &value); -/// Add a vector of type ID to protobuf message. -template -inline void IdVectorToProtobuf(const std::vector &ids, Message &message, - AddFunction add_func) { - for (const auto &id : ids) { - (message.*add_func)(id.Binary()); - } -} - /// Converts a Protobuf `RepeatedField` to a vector of IDs. template inline std::vector IdVectorFromProtobuf( diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index b3b81a906..d0a6000c6 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -20,12 +20,8 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000) /// warning is logged that the handler is taking too long. RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100) -/// The duration between heartbeats. This value is used for both worker and raylet. +/// The duration between heartbeats. These are sent by the raylet. RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100) -/// Worker heartbeats also use `heartbeat_timeout_milliseconds` as timer timeout period. 
-/// If a worker has not sent a heartbeat in the last `num_worker_heartbeats_timeout` -/// heartbeat intervals, raylet will mark this worker as dead. -RAY_CONFIG(int64_t, num_worker_heartbeats_timeout, 30) /// If a component has not sent a heartbeat in the last num_heartbeats_timeout /// heartbeat intervals, the raylet monitor process will report /// it as dead to the db_client table. @@ -156,9 +152,6 @@ RAY_CONFIG(int32_t, num_actor_checkpoints_to_keep, 20) /// Maximum number of ids in one batch to send to GCS to delete keys. RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000) -/// Number of times for a raylet client to retry to register. -RAY_CONFIG(int, num_raylet_client_retry_times, 25) - /// When getting objects from object store, print a warning every this number of attempts. RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50) diff --git a/src/ray/common/task/scheduling_resources.cc b/src/ray/common/task/scheduling_resources.cc index 4668caff4..5463b0933 100644 --- a/src/ray/common/task/scheduling_resources.cc +++ b/src/ray/common/task/scheduling_resources.cc @@ -674,22 +674,37 @@ std::string ResourceIdSet::ToString() const { return return_string; } -std::vector ResourceIdSet::ToProtobuf() const { - std::vector resources; +std::vector> ResourceIdSet::ToFlatbuf( + flatbuffers::FlatBufferBuilder &fbb) const { + std::vector> return_message; for (auto const &resource_pair : available_resources_) { - rpc::ResourceIdSetInfo resource_id_set_info; - resource_id_set_info.set_resource_name(resource_pair.first); + std::vector resource_ids; + std::vector resource_fractions; for (auto whole_id : resource_pair.second.WholeIds()) { - resource_id_set_info.add_resource_ids(whole_id); - resource_id_set_info.add_resource_fractions(1); + resource_ids.push_back(whole_id); + resource_fractions.push_back(1); } + for (auto const &fractional_pair : resource_pair.second.FractionalIds()) { - resource_id_set_info.add_resource_ids(fractional_pair.first); - 
resource_id_set_info.add_resource_fractions(fractional_pair.second.ToDouble()); + resource_ids.push_back(fractional_pair.first); + resource_fractions.push_back(fractional_pair.second.ToDouble()); } - resources.emplace_back(resource_id_set_info); + + auto resource_id_set_message = protocol::CreateResourceIdSetInfo( + fbb, fbb.CreateString(resource_pair.first), fbb.CreateVector(resource_ids), + fbb.CreateVector(resource_fractions)); + + return_message.push_back(resource_id_set_message); } - return resources; + + return return_message; +} + +const std::string ResourceIdSet::Serialize() const { + flatbuffers::FlatBufferBuilder fbb; + auto resource_id_set_flatbuf = ToFlatbuf(fbb); + fbb.Finish(fbb.CreateVector(resource_id_set_flatbuf)); + return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); } /// SchedulingResources class implementation diff --git a/src/ray/common/task/scheduling_resources.h b/src/ray/common/task/scheduling_resources.h index 7076fc408..4c5a7121a 100644 --- a/src/ray/common/task/scheduling_resources.h +++ b/src/ray/common/task/scheduling_resources.h @@ -6,7 +6,7 @@ #include #include -#include "ray/protobuf/common.pb.h" +#include "ray/raylet/format/node_manager_generated.h" namespace ray { @@ -422,10 +422,18 @@ class ResourceIdSet { /// \return A human-readable string version of the object. std::string ToString() const; - /// \brief Convert this object to a vector of protobuf `ResourceIdSetInfo`s. + /// \brief Serialize this object using flatbuffers. /// - /// \return A vector inclusing resource id set infos. - std::vector ToProtobuf() const; + /// \param fbb A flatbuffer builder object. + /// \return A flatbuffer serialized version of this object. + std::vector> ToFlatbuf( + flatbuffers::FlatBufferBuilder &fbb) const; + + /// \brief Serialize this object as a string. + /// + /// \return A serialized string of this object. + /// TODO(zhijunfu): this can be removed after raylet client is migrated to grpc. 
+ const std::string Serialize() const; private: /// A mapping from resource name to a set of resource IDs for that resource. diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index aba8e0eb3..4c37ebdd4 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -3,9 +3,9 @@ #include +#include "ray/common/task/task_common.h" #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" -#include "ray/protobuf/common.pb.h" namespace ray { diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 09b8ee61b..aeb433065 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -6,7 +6,7 @@ #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/task/task_spec.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" #include "ray/util/util.h" namespace ray { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 218881002..249e78fe4 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -8,12 +8,10 @@ #include "ray/core_worker/task_execution.h" #include "ray/core_worker/task_interface.h" #include "ray/gcs/redis_gcs_client.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" namespace ray { -using rpc::RayletClient; - /// The root class that contains all the core and language-independent functionalities /// of the worker. This class is supposed to be used to implement app-language (Java, /// Python, etc) workers. 
diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc index 56f0e94ec..e84e4c51e 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc @@ -4,9 +4,9 @@ #include "ray/core_worker/common.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" -inline ray::RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) { +inline RayletClient &GetRayletClientFromPointer(jlong nativeCoreWorkerPointer) { return reinterpret_cast(nativeCoreWorkerPointer)->GetRayletClient(); } diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index d5b63a6aa..bf8f71ecf 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -11,8 +11,6 @@ namespace ray { -using rpc::RayletClient; - class CoreWorker; class CoreWorkerStoreProvider; diff --git a/src/ray/core_worker/store_provider/local_plasma_provider.h b/src/ray/core_worker/store_provider/local_plasma_provider.h index 912709c46..d67b916b9 100644 --- a/src/ray/core_worker/store_provider/local_plasma_provider.h +++ b/src/ray/core_worker/store_provider/local_plasma_provider.h @@ -7,7 +7,7 @@ #include "ray/common/status.h" #include "ray/core_worker/common.h" #include "ray/core_worker/store_provider/store_provider.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" namespace ray { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 797b71834..1ef6a8b1d 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -8,12 +8,10 @@ 
#include "ray/core_worker/common.h" #include "ray/core_worker/store_provider/local_plasma_provider.h" #include "ray/core_worker/store_provider/store_provider.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" namespace ray { -using rpc::RayletClient; - class CoreWorker; /// The class provides implementations for accessing plasma store, which includes both diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index bca4011d0..35f481b41 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -17,8 +17,6 @@ namespace ray { -using rpc::RayletClient; - class CoreWorker; /// Options of a non-actor-creation task. diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index d18fc0969..6ccd62a3e 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -10,7 +10,7 @@ #include "ray/core_worker/store_provider/local_plasma_provider.h" #include "ray/core_worker/store_provider/memory_store_provider.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" #include "src/ray/protobuf/direct_actor.grpc.pb.h" #include "src/ray/protobuf/direct_actor.pb.h" #include "src/ray/util/test_util.h" diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 10ed146c0..0ba8feb5e 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -5,13 +5,11 @@ #include "ray/core_worker/object_interface.h" #include "ray/core_worker/transport/transport.h" -#include "ray/rpc/raylet/raylet_client.h" +#include "ray/raylet/raylet_client.h" #include "ray/rpc/worker/worker_server.h" namespace ray { -using rpc::RayletClient; - /// In raylet task submitter and receiver, a task is submitted to raylet, and possibly /// gets forwarded to another raylet on which 
node the task should be executed, and /// then a worker on that node gets this task and starts executing it. diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index 1f7ef100c..84eaaf250 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -769,8 +769,9 @@ void ObjectManager::SpreadFreeObjectsRequest( const std::vector> &rpc_clients) { // This code path should be called from node manager. rpc::FreeObjectsRequest free_objects_request; - IdVectorToProtobuf( - object_ids, free_objects_request, &rpc::FreeObjectsRequest::add_object_ids); + for (const auto &e : object_ids) { + free_objects_request.add_object_ids(e.Binary()); + } for (auto &rpc_client : rpc_clients) { rpc_client->FreeObjects(free_objects_request, [](const Status &status, diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 46f740f40..badcbea5b 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -11,18 +11,6 @@ enum Language { CPP = 2; } -// Resource id set info. -message ResourceIdSetInfo { - // The name of the resource. - bytes resource_name = 1; - // The resource IDs reserved for this worker. - repeated uint64 resource_ids = 2; - // The fraction of each resource ID that is reserved for this worker. Note - // that the length of this list must be the same as the length of - // resource_ids. - repeated double resource_fractions = 3; -} - // Type of a worker. enum WorkerType { WORKER = 0; diff --git a/src/ray/protobuf/raylet.proto b/src/ray/protobuf/raylet.proto deleted file mode 100644 index ee71ebcec..000000000 --- a/src/ray/protobuf/raylet.proto +++ /dev/null @@ -1,215 +0,0 @@ -syntax = "proto3"; - -package ray.rpc; - -import "src/ray/protobuf/common.proto"; -import "src/ray/protobuf/gcs.proto"; - -/// NOTE(Joey Jiang) Every request defined in this file should have a `worker_id` field, -/// which will be used in `NodeManager::PreprocessRequest`. 
- -/// Service request and reply messages. -message RegisterClientRequest { - // The worker id. - bytes worker_id = 1; - // Indicates the client is a worker or a driver. - bool is_worker = 2; - // The process ID of this worker. - uint32 worker_pid = 3; - // The job ID. - bytes job_id = 4; - // Language of this worker. - Language language = 5; - // Port that this worker is listening on. - // If port > 0, then worker will listen to this port and wait for - // raylet to push tasks, instead of invoking GetTask(). - int32 port = 6; -} -message RegisterClientReply { - repeated int32 gpu_ids = 1; -} - -message SubmitTaskRequest { - bytes worker_id = 1; - TaskSpec task_spec = 2; -} -message SubmitTaskReply { -} - -message DisconnectClientRequest { - bytes worker_id = 1; -} -message DisconnectClientReply { -} - -message GetTaskRequest { - bytes worker_id = 1; -} -message GetTaskReply { - // A string of bytes representing the task specification. - bytes task_spec = 1; - // A list of the resources reserved for this worker. - repeated ResourceIdSetInfo fractional_resource_ids = 2; -} - -message TaskDoneRequest { - bytes worker_id = 1; -} -message TaskDoneReply { -} - -message FetchOrReconstructRequest { - // The worker ID. - bytes worker_id = 1; - // List of object IDs of the objects that we want to reconstruct or fetch. - repeated bytes object_ids = 2; - // Indicates that we only want to fetch objects, not reconstruct them. - bool fetch_only = 3; - // The current task ID. If fetch_only is false, then this task is blocked. - bytes task_id = 4; -} -message FetchOrReconstructReply { -} - -message NotifyUnblockedRequest { - bytes worker_id = 1; - // The current task ID. This task is no longer blocked. - bytes task_id = 2; -} -message NotifyUnblockedReply { -} - -message WaitRequest { - // The worker ID. - bytes worker_id = 1; - // List of object ids we'll be waiting on. - repeated bytes object_ids = 2; - // Number of objects expected to be returned, if available. 
- uint64 num_ready_objects = 3; - // Timeout in milliseconds. - int64 timeout = 4; - // Whether to wait until objects appear locally. - bool wait_local = 5; - // The current task ID. If there are less than num_ready_objects local, then - // this task is blocked. - bytes task_id = 6; -} -message WaitReply { - // List of object ids found. - repeated bytes found = 1; - // List of object ids not found. - repeated bytes remaining = 2; -} - -message PushErrorRequest { - // The worker ID. - bytes worker_id = 1; - // The job id that the error is for. - bytes job_id = 2; - // The type of the error. - bytes type = 3; - // The error message. - bytes error_message = 4; - // The timestamp of the error message. - double timestamp = 5; -} -message PushErrorReply { -} - -message PushProfileEventsRequest { - bytes worker_id = 1; - ProfileTableData profile_table_data = 2; -} -message PushProfileEventsReply { -} - -message FreeObjectsInStoreRequest { - // The worker ID. - bytes worker_id = 1; - // Whether keep this request within the local object store - // or send it to all of the object stores. - bool local_only = 2; - // Whether also delete objects' creating tasks from GCS. - bool delete_creating_tasks = 3; - // List of object ids to delete from the object store. - repeated bytes object_ids = 4; -} -message FreeObjectsInStoreReply { -} - -message PrepareActorCheckpointRequest { - bytes worker_id = 1; - bytes actor_id = 2; -} -message PrepareActorCheckpointReply { - bytes worker_id = 1; - bytes checkpoint_id = 2; -} - -message NotifyActorResumedFromCheckpointRequest { - bytes worker_id = 1; - // ID of the actor that resumed. - bytes actor_id = 2; - // ID of the checkpoint from which the actor was resumed. - bytes checkpoint_id = 3; -} -message NotifyActorResumedFromCheckpointReply { -} - -message SetResourceRequest { - bytes worker_id = 1; - // Name of the resource to be set. - bytes resource_name = 2; - // Capacity of the resource to be set. 
- double capacity = 3; - // Client ID where this resource will be set. - bytes client_id = 4; -} -message SetResourceReply { -} - -message HeartbeatRequest { - bytes worker_id = 1; - bool is_worker = 2; -} -message HeartbeatReply { -} - -/// Worker-to-raylet RPC service interface. -service RayletService { - // Register a new worker to the raylet. - rpc RegisterClient(RegisterClientRequest) returns (RegisterClientReply); - // Submit a task to the raylet. - rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply); - // Disconnect this client from raylet gracefully. - rpc DisconnectClient(DisconnectClientRequest) returns (DisconnectClientReply); - // Get a new task from the raylet. - rpc GetTask(GetTaskRequest) returns (GetTaskReply); - // Notify the raylet that a task is finished. - rpc TaskDone(TaskDoneRequest) returns (TaskDoneReply); - // Reconstruct or fetch possibly lost objects. - rpc FetchOrReconstruct(FetchOrReconstructRequest) returns (FetchOrReconstructReply); - // For a worker that was blocked on some object(s), tell the raylet - // that the worker is now unblocked. - rpc NotifyUnblocked(NotifyUnblockedRequest) returns (NotifyUnblockedReply); - // Wait for objects to be ready either from local or remote plasma stores. - // The `WaitReply` contains the objects found and objects remaining. - rpc Wait(WaitRequest) returns (WaitReply); - // Push an error to the relevant driver. - rpc PushError(PushErrorRequest) returns (PushErrorReply); - // Push some profiling events to the GCS. When sending this message to the - // node manager, the message itself is serialized as a ProfileTableData object. - rpc PushProfileEvents(PushProfileEventsRequest) returns (PushProfileEventsReply); - // Free the objects in plasma objects store. - rpc FreeObjectsInStore(FreeObjectsInStoreRequest) returns (FreeObjectsInStoreReply); - // Request raylet backend to prepare a checkpoint for an actor. 
- rpc PrepareActorCheckpoint(PrepareActorCheckpointRequest) - returns (PrepareActorCheckpointReply); - // Notify raylet backend that an actor was resumed from a checkpoint. - rpc NotifyActorResumedFromCheckpoint(NotifyActorResumedFromCheckpointRequest) - returns (NotifyActorResumedFromCheckpointReply); - // Set dynamic custom resource. - rpc SetResource(SetResourceRequest) returns (SetResourceReply); - // Send a heartbeat message to raylet. - rpc Heartbeat(HeartbeatRequest) returns (HeartbeatReply); -} diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto index 3e2f83055..3c2d30ab6 100644 --- a/src/ray/protobuf/worker.proto +++ b/src/ray/protobuf/worker.proto @@ -8,7 +8,9 @@ message AssignTaskRequest { // The task to be pushed. Task task = 1; // A list of the resources reserved for this worker. - repeated ResourceIdSetInfo resource_ids = 2; + // TODO(zhijunfu): `resource_ids` is represented as + // flatbuffers-serialized bytes, will be moved to protobuf later. + bytes resource_ids = 2; } message AssignTaskReply { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc new file mode 100644 index 000000000..a9ef670b9 --- /dev/null +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -0,0 +1,292 @@ +#include "ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.h" + +#include + +#include "ray/common/id.h" +#include "ray/core_worker/lib/java/jni_utils.h" +#include "ray/raylet/raylet_client.h" +#include "ray/util/logging.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeInit + * Signature: (Ljava/lang/String;[BZ[B)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( + JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, + jbyteArray jobId) { + const auto worker_id = JavaByteArrayToId(env, 
workerId); + const auto job_id = JavaByteArrayToId(env, jobId); + const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); + auto raylet_client = new std::unique_ptr( + new RayletClient(nativeString, worker_id, isWorker, job_id, Language::JAVA)); + env->ReleaseStringUTFChars(sockName, nativeString); + return reinterpret_cast(raylet_client); +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSubmitTask + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSubmitTask( + JNIEnv *env, jclass, jlong client, jbyteArray taskSpec) { + auto &raylet_client = *reinterpret_cast *>(client); + + jbyte *data = env->GetByteArrayElements(taskSpec, NULL); + jsize size = env->GetArrayLength(taskSpec); + ray::rpc::TaskSpec task_spec_message; + task_spec_message.ParseFromArray(data, size); + env->ReleaseByteArrayElements(taskSpec, data, JNI_ABORT); + + ray::TaskSpecification task_spec(task_spec_message); + auto status = raylet_client->SubmitTask(task_spec); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGetTask + * Signature: (J)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGetTask( + JNIEnv *env, jclass, jlong client) { + auto &raylet_client = *reinterpret_cast *>(client); + + std::unique_ptr spec; + auto status = raylet_client->GetTask(&spec); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + + // Serialize the task spec and copy to Java byte array. 
+ auto task_data = spec->Serialize(); + + jbyteArray result = env->NewByteArray(task_data.size()); + if (result == nullptr) { + return nullptr; /* out of memory error thrown */ + } + + env->SetByteArrayRegion(result, 0, task_data.size(), + reinterpret_cast(task_data.data())); + + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeDestroy + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeDestroy( + JNIEnv *env, jclass, jlong client) { + auto raylet_client = reinterpret_cast *>(client); + auto status = (*raylet_client)->Disconnect(); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); + delete raylet_client; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeWaitObject + * Signature: (J[[BIIZ[B)[Z + */ +JNIEXPORT jbooleanArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( + JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jint numReturns, + jint timeoutMillis, jboolean isWaitLocal, jbyteArray currentTaskId) { + std::vector object_ids; + auto len = env->GetArrayLength(objectIds); + for (int i = 0; i < len; i++) { + jbyteArray object_id_bytes = + static_cast(env->GetObjectArrayElement(objectIds, i)); + const auto object_id = JavaByteArrayToId(env, object_id_bytes); + object_ids.push_back(object_id); + env->DeleteLocalRef(object_id_bytes); + } + const auto current_task_id = JavaByteArrayToId(env, currentTaskId); + + auto &raylet_client = *reinterpret_cast *>(client); + + // Invoke wait. + WaitResultPair result; + auto status = + raylet_client->Wait(object_ids, numReturns, timeoutMillis, + static_cast(isWaitLocal), current_task_id, &result); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + + // Convert result to java object. 
+ jboolean put_value = true; + jbooleanArray resultArray = env->NewBooleanArray(object_ids.size()); + for (uint i = 0; i < result.first.size(); ++i) { + for (uint j = 0; j < object_ids.size(); ++j) { + if (result.first[i] == object_ids[j]) { + env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); + break; + } + } + } + + put_value = false; + for (uint i = 0; i < result.second.size(); ++i) { + for (uint j = 0; j < object_ids.size(); ++j) { + if (result.second[i] == object_ids[j]) { + env->SetBooleanArrayRegion(resultArray, j, 1, &put_value); + break; + } + } + } + return resultArray; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateActorCreationTaskId + * Signature: ([B[BI)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorCreationTaskId( + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, + jint parent_task_counter) { + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); + + const ActorID actor_id = ray::ActorID::Of(job_id, parent_task_id, parent_task_counter); + const TaskID actor_creation_task_id = ray::TaskID::ForActorCreationTask(actor_id); + jbyteArray result = env->NewByteArray(actor_creation_task_id.Size()); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, actor_creation_task_id.Size(), + reinterpret_cast(actor_creation_task_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateActorTaskId + * Signature: ([B[BI[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateActorTaskId( + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, + jint parent_task_counter, jbyteArray actorId) { + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); + const auto actor_id = 
JavaByteArrayToId(env, actorId); + const TaskID actor_task_id = + ray::TaskID::ForActorTask(job_id, parent_task_id, parent_task_counter, actor_id); + + jbyteArray result = env->NewByteArray(actor_task_id.Size()); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, actor_task_id.Size(), + reinterpret_cast(actor_task_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeGenerateNormalTaskId + * Signature: ([B[BI)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateNormalTaskId( + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, + jint parent_task_counter) { + const auto job_id = JavaByteArrayToId(env, jobId); + const auto parent_task_id = JavaByteArrayToId(env, parentTaskId); + const TaskID task_id = + ray::TaskID::ForNormalTask(job_id, parent_task_id, parent_task_counter); + + jbyteArray result = env->NewByteArray(task_id.Size()); + if (nullptr == result) { + return nullptr; + } + env->SetByteArrayRegion(result, 0, task_id.Size(), + reinterpret_cast(task_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeFreePlasmaObjects + * Signature: (J[[BZZ)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeFreePlasmaObjects( + JNIEnv *env, jclass, jlong client, jobjectArray objectIds, jboolean localOnly, + jboolean deleteCreatingTasks) { + std::vector object_ids; + auto len = env->GetArrayLength(objectIds); + for (int i = 0; i < len; i++) { + jbyteArray object_id_bytes = + static_cast(env->GetObjectArrayElement(objectIds, i)); + const auto object_id = JavaByteArrayToId(env, object_id_bytes); + object_ids.push_back(object_id); + env->DeleteLocalRef(object_id_bytes); + } + auto &raylet_client = *reinterpret_cast *>(client); + auto status = raylet_client->FreeObjects(object_ids, localOnly, deleteCreatingTasks); + 
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativePrepareCheckpoint + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativePrepareCheckpoint(JNIEnv *env, jclass, + jlong client, + jbyteArray actorId) { + auto &raylet_client = *reinterpret_cast *>(client); + const auto actor_id = JavaByteArrayToId(env, actorId); + ActorCheckpointID checkpoint_id; + auto status = raylet_client->PrepareActorCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + jbyteArray result = env->NewByteArray(checkpoint_id.Size()); + env->SetByteArrayRegion(result, 0, checkpoint_id.Size(), + reinterpret_cast(checkpoint_id.Data())); + return result; +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeNotifyActorResumedFromCheckpoint + * Signature: (J[B[B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_raylet_RayletClientImpl_nativeNotifyActorResumedFromCheckpoint( + JNIEnv *env, jclass, jlong client, jbyteArray actorId, jbyteArray checkpointId) { + auto &raylet_client = *reinterpret_cast *>(client); + const auto actor_id = JavaByteArrayToId(env, actorId); + const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); + auto status = raylet_client->NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +/* + * Class: org_ray_runtime_raylet_RayletClientImpl + * Method: nativeSetResource + * Signature: (JLjava/lang/String;D[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeSetResource( + JNIEnv *env, jclass, jlong client, jstring resourceName, jdouble capacity, + jbyteArray nodeId) { + auto &raylet_client = *reinterpret_cast *>(client); + const auto node_id = JavaByteArrayToId(env, nodeId); + const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); + 
+ auto status = raylet_client->SetResource(native_resource_name, + static_cast(capacity), node_id); + env->ReleaseStringUTFChars(resourceName, native_resource_name); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); +} + +#ifdef __cplusplus +} +#endif diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index e9591f755..1eb6c3cf4 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -7,6 +7,7 @@ #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" +#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/lineage_cache.h" #include "ray/util/test_util.h" diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 82ce87f44..ee7c0ece6 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -1,12 +1,12 @@ #include "ray/raylet/node_manager.h" #include -#include #include "ray/common/status.h" #include "ray/common/common_protocol.h" #include "ray/common/id.h" +#include "ray/raylet/format/node_manager_generated.h" #include "ray/stats/stats.h" namespace { @@ -14,20 +14,6 @@ namespace { #define RAY_CHECK_ENUM(x, y) \ static_assert(static_cast(x) == static_cast(y), "protocol mismatch") -/// Macro to handle early return for preprocessing. -/// An early return will take place if the worker is being killed due to the exiting of -/// driver, or the worker is not registered yet. 
-#define PREPROCESS_WORKER_REQUEST(REQUEST_TYPE, REQUEST, SEND_REPLY) \ - do { \ - WorkerID worker_id = WorkerID::FromBinary(REQUEST.worker_id()); \ - if (!PreprocessRequest(worker_id, #REQUEST_TYPE)) { \ - SEND_REPLY( \ - Status::Invalid("Discard this request due to failure of preprocessing."), \ - nullptr, nullptr); \ - return; \ - } \ - } while (0) - /// A helper function to return the expected actor counter for a given actor /// and actor handle, according to the given actor registry. If a task's /// counter is less than the returned value, then the task is a duplicate. If @@ -90,7 +76,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, gcs_client_(std::move(gcs_client)), object_directory_(std::move(object_directory)), heartbeat_timer_(io_service), - heartbeat_period_(config.heartbeat_period_ms), + heartbeat_period_(std::chrono::milliseconds(config.heartbeat_period_ms)), debug_dump_period_(config.debug_dump_period_ms), temp_dir_(config.temp_dir), object_manager_profile_timer_(io_service), @@ -119,7 +105,6 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, actor_registry_(), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), - raylet_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -263,7 +248,6 @@ void NodeManager::KillWorker(std::shared_ptr worker) { // up its state before force killing. The client socket will be closed // and the worker struct will be freed after the timeout. kill(worker->Pid(), SIGTERM); - worker->MarkAsBeingKilled(); auto retry_timer = std::make_shared(io_service_); auto retry_duration = boost::posix_time::milliseconds( @@ -287,10 +271,13 @@ void NodeManager::HandleJobTableUpdate(const JobID &id, auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id); // Kill all the workers. 
The actual cleanup for these workers is done - // later when the worker heartbeats timeout. + // later when we receive the DisconnectClient message from them. for (const auto &worker : workers) { // Clean up any open ray.wait calls that the worker made. task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); + // Mark the worker as dead so further messages from it are ignored + // (except DisconnectClient). + worker->MarkDead(); // Then kill the worker process. KillWorker(worker); } @@ -352,24 +339,9 @@ void NodeManager::Heartbeat() { last_debug_dump_at_ms_ = now_ms; } - // Check worker heartbeat timeout times. - std::vector> dead_workers; - worker_pool_.TickHeartbeatTimer(RayConfig::instance().num_worker_heartbeats_timeout(), - &dead_workers); - if (!dead_workers.empty()) { - for (const auto &worker : dead_workers) { - RAY_LOG(INFO) << "Worker " << worker->GetWorkerId() - << " dead because of timeout, pid: " << worker->Pid(); - ProcessDisconnectClientMessage(worker->GetWorkerId(), worker->IsBeingKilled()); - } - } - // Reset the timer. heartbeat_timer_.expires_from_now(heartbeat_period_); heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { - if (error == boost::asio::error::operation_aborted) { - return; - } RAY_CHECK(!error); Heartbeat(); }); @@ -709,12 +681,9 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } -void NodeManager::CleanUpTasksForFinishedJob(const JobID &job_id) { - auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); - task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); - // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must - // call it last. - local_queues_.RemoveTasks(tasks_to_remove); +void NodeManager::ProcessNewClient(LocalClientConnection &client) { + // The new client is a worker, so begin listening for messages. 
+ client.ProcessMessages(); } // A helper function to create a mapping from resource shapes to @@ -731,7 +700,7 @@ std::unordered_map> MakeTasksWithResources( void NodeManager::DispatchTasks( const std::unordered_map> &tasks_with_resources) { - std::unordered_set assigned_task_ids; + std::unordered_set removed_task_ids; for (const auto &it : tasks_with_resources) { const auto &task_resources = it.first; for (const auto &task_id : it.second) { @@ -742,85 +711,149 @@ void NodeManager::DispatchTasks( break; } if (AssignTask(task)) { - assigned_task_ids.insert(task_id); + removed_task_ids.insert(task_id); } } } - - // Move the ASSIGNED task to the RUNNING queue. - // We should move task outside `AssignTask` function because removing - // task might influence the iterator. - local_queues_.MoveTasks(assigned_task_ids, TaskState::READY, TaskState::RUNNING); + // Move the ASSIGNED task to the SWAP queue so that we remember that we have + // it queued locally. Once the GetTaskReply has been sent, the task will get + // re-queued, depending on whether the message succeeded or not. + local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP); } -bool NodeManager::PreprocessRequest(const WorkerID &worker_id, - const std::string &request_name) { - std::ostringstream oss; - if (RAY_LOG_ENABLED(DEBUG)) { - oss << "Received a " << request_name << " request. Worker id " << worker_id << "."; +void NodeManager::ProcessClientMessage( + const std::shared_ptr &client, int64_t message_type, + const uint8_t *message_data) { + auto registered_worker = worker_pool_.GetRegisteredWorker(client); + auto message_type_value = static_cast(message_type); + RAY_LOG(DEBUG) << "[Worker] Message " + << protocol::EnumNameMessageType(message_type_value) << "(" + << message_type << ") from worker with PID " + << (registered_worker ? 
std::to_string(registered_worker->Pid()) + : "nil"); + if (registered_worker && registered_worker->IsDead()) { + // For a worker that is marked as dead (because the job has died already), + // all the messages are ignored except DisconnectClient. + if ((message_type_value != protocol::MessageType::DisconnectClient) && + (message_type_value != protocol::MessageType::IntentionalDisconnectClient)) { + // Listen for more messages. + client->ProcessMessages(); + return; + } } - auto worker = worker_pool_.GetWorker(worker_id); - // Worker process has been killed, we should discard this request. - if (!worker) { - RAY_LOG(WARNING) << "Worker " << worker_id << " is not found in worker pool, request " - << request_name << " will be discarded."; - return false; - } - if (RAY_LOG_ENABLED(DEBUG)) { - oss << " Is worker: " << (worker->IsWorker() ? "true" : "false") << ". Worker pid " - << std::to_string(worker->Pid()) << "."; - RAY_LOG(DEBUG) << oss.str(); + switch (message_type_value) { + case protocol::MessageType::RegisterClientRequest: { + ProcessRegisterClientRequestMessage(client, message_data); + } break; + case protocol::MessageType::GetTask: { + RAY_CHECK(!registered_worker->UsePush()); + HandleWorkerAvailable(client); + } break; + case protocol::MessageType::TaskDone: { + RAY_CHECK(registered_worker->UsePush()); + HandleWorkerAvailable(client); + } break; + case protocol::MessageType::DisconnectClient: { + ProcessDisconnectClientMessage(client); + // We don't need to receive future messages from this client, + // because it's already disconnected. + return; + } break; + case protocol::MessageType::IntentionalDisconnectClient: { + ProcessDisconnectClientMessage(client, /* intentional_disconnect = */ true); + // We don't need to receive future messages from this client, + // because it's already disconnected. 
+ return; + } break; + case protocol::MessageType::SubmitTask: { + ProcessSubmitTaskMessage(message_data); + } break; + case protocol::MessageType::SetResourceRequest: { + ProcessSetResourceRequest(client, message_data); + } break; + case protocol::MessageType::FetchOrReconstruct: { + ProcessFetchOrReconstructMessage(client, message_data); + } break; + case protocol::MessageType::NotifyUnblocked: { + auto message = flatbuffers::GetRoot(message_data); + HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); + } break; + case protocol::MessageType::WaitRequest: { + ProcessWaitRequestMessage(client, message_data); + } break; + case protocol::MessageType::PushErrorRequest: { + ProcessPushErrorRequestMessage(message_data); + } break; + case protocol::MessageType::PushProfileEventsRequest: { + auto fbs_message = flatbuffers::GetRoot(message_data); + rpc::ProfileTableData profile_table_data; + RAY_CHECK( + profile_table_data.ParseFromArray(fbs_message->data(), fbs_message->size())); + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); + } break; + case protocol::MessageType::FreeObjectsInObjectStoreRequest: { + auto message = flatbuffers::GetRoot(message_data); + std::vector object_ids = from_flatbuf(*message->object_ids()); + // Clean up objects from the object store. + object_manager_.FreeObjects(object_ids, message->local_only()); + if (message->delete_creating_tasks()) { + // Clean up their creating tasks from GCS. 
+ std::vector creating_task_ids; + for (const auto &object_id : object_ids) { + creating_task_ids.push_back(object_id.TaskId()); + } + gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); + } + } break; + case protocol::MessageType::PrepareActorCheckpointRequest: { + ProcessPrepareActorCheckpointRequest(client, message_data); + } break; + case protocol::MessageType::NotifyActorResumedFromCheckpoint: { + ProcessNotifyActorResumedFromCheckpoint(message_data); + } break; + + default: + RAY_LOG(FATAL) << "Received unexpected message type " << message_type; } - // The worker process is being killing, we should discard this request. - if (worker->IsBeingKilled()) { - RAY_LOG(INFO) << "Worker " << worker_id << " is being killed, request " - << request_name << " will be discarded."; - return false; - } - - return true; + // Listen for more messages. + client->ProcessMessages(); } -void NodeManager::HandleRegisterClientRequest( - const rpc::RegisterClientRequest &request, rpc::RegisterClientReply *reply, - rpc::SendReplyCallback send_reply_callback) { - // Client id in register client is treated as worker id. - const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - bool is_worker = request.is_worker(); - auto worker = - std::make_shared(worker_id, request.worker_pid(), request.language(), - request.port(), client_call_manager_, is_worker); - - RAY_LOG(DEBUG) << "Received a RegisterClientRequest. Worker id: " << worker_id - << ". Is worker: " << is_worker << ". 
Worker pid " - << request.worker_pid(); - +void NodeManager::ProcessRegisterClientRequestMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + client->Register(); + auto message = flatbuffers::GetRoot(message_data); + Language language = static_cast(message->language()); + WorkerID worker_id = from_flatbuf(*message->worker_id()); + auto worker = std::make_shared(worker_id, message->worker_pid(), language, + message->port(), client, client_call_manager_); Status status; - if (is_worker) { + if (message->is_worker()) { // Register the new worker. bool use_push_task = worker->UsePush(); - status = worker_pool_.RegisterWorker(worker_id, std::move(worker)); + auto connection = worker->Connection(); + status = worker_pool_.RegisterWorker(std::move(worker)); if (status.ok() && use_push_task) { // only call `HandleWorkerAvailable` when push mode is used. - HandleWorkerAvailable(worker_id); + HandleWorkerAvailable(connection); } } else { // Register the new driver. - auto driver_task_id = TaskID::ComputeDriverTaskId(worker_id); - auto job_id = JobID::FromBinary(request.job_id()); + const JobID job_id = from_flatbuf(*message->job_id()); + // Compute a dummy driver task id from a given driver. 
+ const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id); worker->AssignTaskId(driver_task_id); worker->AssignJobId(job_id); - status = worker_pool_.RegisterDriver(worker_id, std::move(worker)); + status = worker_pool_.RegisterDriver(std::move(worker)); if (status.ok()) { local_queues_.AddDriverTaskId(driver_task_id); RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( job_id, /*is_dead=*/false, std::time(nullptr), - initial_config_.node_manager_address, request.worker_pid())); + initial_config_.node_manager_address, message->worker_pid())); } } - send_reply_callback(status, nullptr, nullptr); } void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local, @@ -867,78 +900,83 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate(actor_id, actor_notification, done)); } -void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request, - rpc::GetTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(GetTaskRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); - - RAY_CHECK(!worker->UsePush()); - // Reply would be sent when assigned a task to the worker successfully. 
- worker->SetGetTaskReplyAndCallback(reply, std::move(send_reply_callback)); - HandleWorkerAvailable(worker_id); -} - -void NodeManager::HandleTaskDoneRequest(const rpc::TaskDoneRequest &request, - rpc::TaskDoneReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(TaskDoneRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - auto worker = worker_pool_.GetRegisteredWorker(worker_id); - RAY_CHECK(worker && worker->UsePush()); - HandleWorkerAvailable(worker_id); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - -void NodeManager::HandleDisconnectClientRequest( - const rpc::DisconnectClientRequest &request, rpc::DisconnectClientReply *reply, - rpc::SendReplyCallback send_reply_callback) { - const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - ProcessDisconnectClientMessage(worker_id, true); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - -void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, - bool intentional_disconnect) { - auto worker = worker_pool_.GetWorker(worker_id); - if (!worker) { - RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " - << "been disconnected."; - return; +void NodeManager::HandleWorkerAvailable( + const std::shared_ptr &client) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + RAY_CHECK(worker); + // If the worker was assigned a task, mark it as finished. + if (!worker->GetAssignedTaskId().IsNil()) { + FinishAssignedTask(*worker); } - bool is_worker = worker->IsWorker(); + + // Return the worker to the idle pool. + worker_pool_.PushWorker(std::move(worker)); + // Local resource availability changed: invoke scheduling policy for local node. 
+ const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + cluster_resource_map_[local_client_id].SetLoadResources( + local_queues_.GetResourceLoad()); + // Call task dispatch to assign work to the new worker. + DispatchTasks(local_queues_.GetReadyTasksWithResources()); +} + +void NodeManager::ProcessDisconnectClientMessage( + const std::shared_ptr &client, bool intentional_disconnect) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + bool is_worker = false, is_driver = false; + if (worker) { + // The client is a worker. + is_worker = true; + } else { + worker = worker_pool_.GetRegisteredDriver(client); + if (worker) { + // The client is a driver. + is_driver = true; + } else { + RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " + << "been disconnected."; + } + } + RAY_CHECK(!(is_worker && is_driver)); // If the client has any blocked tasks, mark them as unblocked. In // particular, we are no longer waiting for their dependencies. - if (is_worker && worker->IsBeingKilled()) { - // Don't need to unblock the client if it's a worker and have sent kill signal to - // it. Because in this case, its task is already cleaned up. - RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; - } else { - // Clean up any open ray.get calls that the worker made. - while (!worker->GetBlockedTaskIds().empty()) { - // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is - // not safe to pass in the iterator directly. - const TaskID task_id = *worker->GetBlockedTaskIds().begin(); - HandleTaskUnblocked(worker_id, task_id); + if (worker) { + if (is_worker && worker->IsDead()) { + // Don't need to unblock the client if it's a worker and is already dead. + // Because in this case, its task is already cleaned up. + RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; + } else { + // Clean up any open ray.get calls that the worker made. 
+ while (!worker->GetBlockedTaskIds().empty()) { + // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is + // not safe to pass in the iterator directly. + const TaskID task_id = *worker->GetBlockedTaskIds().begin(); + HandleTaskUnblocked(client, task_id); + } + // Clean up any open ray.wait calls that the worker made. + task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); } - // Clean up any open ray.wait calls that the worker made. - task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); } if (is_worker) { + // The client is a worker. + if (worker->IsDead()) { + // If the worker was killed by us because the driver exited, + // treat it as intentionally disconnected. + intentional_disconnect = true; + } + const ActorID &actor_id = worker->GetActorId(); if (!actor_id.IsNil()) { - // If the worker was an actor, update actor state, reconstruct the actor if - // needed, and clean up actor's tasks if the actor is permanently dead. + // If the worker was an actor, update actor state, reconstruct the actor if needed, + // and clean up actor's tasks if the actor is permanently dead. HandleDisconnectedActor(actor_id, true, intentional_disconnect); } const TaskID &task_id = worker->GetAssignedTaskId(); // If the worker was running a task, clean up the task and push an error to - // the driver, unless the worker is already being killed. - if (!task_id.IsNil() && !worker->IsBeingKilled()) { + // the driver, unless the worker is already dead. + if (!task_id.IsNil() && !worker->IsDead()) { // If the worker was an actor, the task was already cleaned up in // `HandleDisconnectedActor`. if (actor_id.IsNil()) { @@ -984,7 +1022,7 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, // Since some resources may have been released, we can try to dispatch more tasks. DispatchTasks(local_queues_.GetReadyTasksWithResources()); - } else { + } else if (is_driver) { // The client is a driver. 
const auto job_id = worker->GetAssignedJobId(); const auto driver_id = ComputeDriverIdFromJob(job_id); @@ -999,34 +1037,32 @@ void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, << "job_id: " << job_id; } + client->Close(); + // TODO(rkn): Tell the object manager that this client has disconnected so // that it can clean up the wait requests for this client. Currently I think // these can be leaked. } -void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, - rpc::SubmitTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(SubmitTaskRequest, request, send_reply_callback); - rpc::Task task; - task.mutable_task_spec()->CopyFrom(request.task_spec()); +void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { + // Read the task submitted by the client. + auto fbs_message = flatbuffers::GetRoot(message_data); + rpc::Task task_message; + RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( + fbs_message->task_spec()->data(), fbs_message->task_spec()->size())); // Submit the task to the raylet. Since the task was submitted // locally, there is no uncommitted lineage. 
- SubmitTask(Task(task), Lineage()); - send_reply_callback(Status::OK(), nullptr, nullptr); + SubmitTask(Task(task_message), Lineage()); } -void NodeManager::HandleFetchOrReconstructRequest( - const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(FetchOrReconstructRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - const auto &object_ids = request.object_ids(); +void NodeManager::ProcessFetchOrReconstructMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = flatbuffers::GetRoot(message_data); std::vector required_object_ids; - for (int64_t i = 0; i < object_ids.size(); ++i) { - ObjectID object_id = ObjectID::FromBinary(object_ids[i]); - if (request.fetch_only()) { + for (int64_t i = 0; i < message->object_ids()->size(); ++i) { + ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); + if (message->fetch_only()) { // If only a fetch is required, then do not subscribe to the // dependencies to the task dependency manager. 
if (!task_dependency_manager_.CheckObjectLocal(object_id)) { @@ -1043,22 +1079,19 @@ void NodeManager::HandleFetchOrReconstructRequest( } if (!required_object_ids.empty()) { - const TaskID task_id = TaskID::FromBinary(request.task_id()); - HandleTaskBlocked(worker_id, required_object_ids, task_id, /*ray_get=*/true); + const TaskID task_id = from_flatbuf(*message->task_id()); + HandleTaskBlocked(client, required_object_ids, task_id, /*ray_get=*/true); } - send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, - rpc::WaitReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(WaitRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); +void NodeManager::ProcessWaitRequestMessage( + const std::shared_ptr &client, const uint8_t *message_data) { // Read the data. - std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); - int64_t wait_ms = request.timeout(); - uint64_t num_required_objects = request.num_ready_objects(); - bool wait_local = request.wait_local(); + auto message = flatbuffers::GetRoot(message_data); + std::vector object_ids = from_flatbuf(*message->object_ids()); + int64_t wait_ms = message->timeout(); + uint64_t num_required_objects = static_cast(message->num_ready_objects()); + bool wait_local = message->wait_local(); std::vector required_object_ids; for (auto const &object_id : object_ids) { @@ -1070,57 +1103,63 @@ void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, } } - const TaskID &current_task_id = TaskID::FromBinary(request.task_id()); + const TaskID &current_task_id = from_flatbuf(*message->task_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { - HandleTaskBlocked(worker_id, required_object_ids, current_task_id, /*ray_get=*/false); + HandleTaskBlocked(client, required_object_ids, current_task_id, /*ray_get=*/false); } ray::Status status =
object_manager_.Wait( object_ids, wait_ms, num_required_objects, wait_local, - [this, client_blocked, worker_id, current_task_id, reply, send_reply_callback]( - std::vector found, std::vector remaining) { - IdVectorToProtobuf(found, *reply, - &rpc::WaitReply::add_found); - IdVectorToProtobuf(remaining, *reply, - &rpc::WaitReply::add_remaining); + [this, client_blocked, client, current_task_id](std::vector found, + std::vector remaining) { + // Write the data. + flatbuffers::FlatBufferBuilder fbb; + flatbuffers::Offset wait_reply = protocol::CreateWaitReply( + fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); + fbb.Finish(wait_reply); - // Send reply to finish this wait request. - send_reply_callback(Status::OK(), nullptr, nullptr); - if (client_blocked) { - HandleTaskUnblocked(worker_id, current_task_id); + auto status = + client->WriteMessage(static_cast(protocol::MessageType::WaitReply), + fbb.GetSize(), fbb.GetBufferPointer()); + if (status.ok()) { + // The client is unblocked now because the wait call has returned. + if (client_blocked) { + HandleTaskUnblocked(client, current_task_id); + } + } else { + // We failed to write to the client, so disconnect the client. + RAY_LOG(WARNING) + << "Failed to send WaitReply to client, so disconnecting client"; + // We failed to send the reply to the client, so disconnect the worker. 
+ ProcessDisconnectClientMessage(client); } }); RAY_CHECK_OK(status); } -void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request, - rpc::PushErrorReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(PushErrorRequest, request, send_reply_callback); - JobID job_id = JobID::FromBinary(request.job_id()); - const auto &type = request.type(); - const auto &error_message = request.error_message(); - double timestamp = request.timestamp(); - RAY_LOG(DEBUG) << "Handle push error request for job " << job_id << ", type " << type - << " error message " << error_message; +void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { + auto message = flatbuffers::GetRoot(message_data); + + JobID job_id = from_flatbuf(*message->job_id()); + auto const &type = string_from_flatbuf(*message->type()); + auto const &error_message = string_from_flatbuf(*message->error_message()); + double timestamp = message->timestamp(); RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, timestamp)); - send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::HandlePrepareActorCheckpointRequest( - const rpc::PrepareActorCheckpointRequest &request, - rpc::PrepareActorCheckpointReply *reply, rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(PrepareActorCheckpointRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - ActorID actor_id = ActorID::FromBinary(request.actor_id()); +void NodeManager::ProcessPrepareActorCheckpointRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + auto message = + flatbuffers::GetRoot(message_data); + ActorID actor_id = from_flatbuf(*message->actor_id()); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); - std::shared_ptr worker = 
worker_pool_.GetRegisteredWorker(worker_id); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); RAY_CHECK(worker && worker->GetActorId() == actor_id); // Find the task that is running on this actor. @@ -1134,36 +1173,40 @@ void NodeManager::HandlePrepareActorCheckpointRequest( // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( JobID::Nil(), checkpoint_id, checkpoint_data, - [worker, actor_id, reply, send_reply_callback, this]( - ray::gcs::RedisGcsClient *client, const ActorCheckpointID &checkpoint_id, - const ActorCheckpointData &data) { + [worker, actor_id, this](ray::gcs::RedisGcsClient *client, + const ActorCheckpointID &checkpoint_id, + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); - // Save this actor-to-checkpoint mapping, and remove old checkpoints - // associated with this actor. + // Save this actor-to-checkpoint mapping, and remove old checkpoints associated + // with this actor. RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId( JobID::Nil(), actor_id, checkpoint_id)); // Send reply to worker. 
- reply->set_checkpoint_id(checkpoint_id.Binary()); - send_reply_callback(Status::OK(), nullptr, []() { - RAY_LOG(WARNING) << "Failed to send PrepareActorCheckpointReply to client"; - }); + flatbuffers::FlatBufferBuilder fbb; + auto reply = ray::protocol::CreatePrepareActorCheckpointReply( + fbb, to_flatbuf(fbb, checkpoint_id)); + fbb.Finish(reply); + worker->Connection()->WriteMessageAsync( + static_cast(protocol::MessageType::PrepareActorCheckpointReply), + fbb.GetSize(), fbb.GetBufferPointer(), [](const ray::Status &status) { + if (!status.ok()) { + RAY_LOG(WARNING) + << "Failed to send PrepareActorCheckpointReply to client"; + } + }); })); } -void NodeManager::HandleNotifyActorResumedFromCheckpointRequest( - const rpc::NotifyActorResumedFromCheckpointRequest &request, - rpc::NotifyActorResumedFromCheckpointReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(NotifyActorResumedFromCheckpointRequest, request, - send_reply_callback); - ActorID actor_id = ActorID::FromBinary(request.actor_id()); +void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) { + auto message = + flatbuffers::GetRoot(message_data); + ActorID actor_id = from_flatbuf(*message->actor_id()); ActorCheckpointID checkpoint_id = - ActorCheckpointID::FromBinary(request.checkpoint_id()); + from_flatbuf(*message->checkpoint_id()); RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint " << checkpoint_id; checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); - send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, @@ -1184,15 +1227,16 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &request, - rpc::SetResourceReply *reply, - rpc::SendReplyCallback send_reply_callback) { - 
PREPROCESS_WORKER_REQUEST(SetResourceRequest, request, send_reply_callback); - auto const &resource_name = request.resource_name(); - double const capacity = request.capacity(); +void NodeManager::ProcessSetResourceRequest( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the SetResource message + auto message = flatbuffers::GetRoot(message_data); + + auto const &resource_name = string_from_flatbuf(*message->resource_name()); + double const &capacity = message->capacity(); bool is_deletion = capacity <= 0; - ClientID client_id = ClientID::FromBinary(request.client_id()); + ClientID client_id = from_flatbuf(*message->client_id()); // If the python arg was null, set client_id to the local client if (client_id.IsNil()) { @@ -1222,62 +1266,6 @@ void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &reques RAY_CHECK_OK( gcs_client_->resource_table().Update(JobID::Nil(), client_id, data_map, nullptr)); } - send_reply_callback(Status::OK(), nullptr, nullptr); -} - -void NodeManager::HandleNotifyUnblockedRequest( - const rpc::NotifyUnblockedRequest &request, rpc::NotifyUnblockedReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(NotifyUnblockedRequest, request, send_reply_callback); - WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); - const TaskID current_task_id = TaskID::FromBinary(request.task_id()); - - HandleTaskUnblocked(worker_id, current_task_id); - send_reply_callback(Status::OK(), nullptr, nullptr); -} - -void NodeManager::HandlePushProfileEventsRequest( - const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(PushProfileEventsRequest, request, send_reply_callback); - const auto &profile_table_data = request.profile_table_data(); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); - send_reply_callback(Status::OK(), nullptr, 
nullptr); -} - -void NodeManager::HandleFreeObjectsInStoreRequest( - const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply, - rpc::SendReplyCallback send_reply_callback) { - PREPROCESS_WORKER_REQUEST(FreeObjectsInStoreRequest, request, send_reply_callback); - std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); - object_manager_.FreeObjects(object_ids, request.local_only()); - if (request.delete_creating_tasks()) { - // Clean up their creating tasks from GCS. - std::vector creating_task_ids; - for (const auto &object_id : object_ids) { - creating_task_ids.push_back(object_id.TaskId()); - } - gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); - } - send_reply_callback(Status::OK(), nullptr, nullptr); -} - -void NodeManager::HandleHeartbeatRequest(const rpc::HeartbeatRequest &request, - rpc::HeartbeatReply *reply, - rpc::SendReplyCallback send_reply_callback) { - bool is_worker = request.is_worker(); - const auto worker_id = WorkerID::FromBinary(request.worker_id()); - - std::shared_ptr worker = nullptr; - if (is_worker) { - worker = worker_pool_.GetRegisteredWorker(worker_id); - } else { - worker = worker_pool_.GetRegisteredDriver(worker_id); - } - if (worker) { - worker->ClearHeartbeat(); - } - send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::ScheduleTasks( @@ -1584,8 +1572,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // This is a non-actor task. Queue the task for a placement decision or for dispatch // if the task was forwarded. if (forwarded) { - // Check for local dependencies and enqueue as waiting or ready for - // dispatch. + // Check for local dependencies and enqueue as waiting or ready for dispatch. EnqueuePlaceableTask(task); } else { // (See design_docs/task_states.rst for the state transition diagram.) 
@@ -1597,10 +1584,10 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } -void NodeManager::HandleTaskBlocked(const WorkerID &worker_id, +void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, const std::vector &required_object_ids, const TaskID &current_task_id, bool ray_get) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); if (worker) { // The client is a worker. If the worker is not already blocked and the // blocked task matches the one assigned to the worker, then mark the @@ -1624,7 +1611,7 @@ void NodeManager::HandleTaskBlocked(const WorkerID &worker_id, } else { // The client is a driver. Drivers do not hold resources, so we simply mark // the task as blocked. - worker = worker_pool_.GetRegisteredDriver(worker_id); + worker = worker_pool_.GetRegisteredDriver(client); } RAY_CHECK(worker); @@ -1646,9 +1633,9 @@ void NodeManager::HandleTaskBlocked(const WorkerID &worker_id, } } -void NodeManager::HandleTaskUnblocked(const WorkerID &worker_id, - const TaskID &current_task_id) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); +void NodeManager::HandleTaskUnblocked( + const std::shared_ptr &client, const TaskID &current_task_id) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); // TODO(swang): Because the object dependencies are tracked in the task // dependency manager, we could actually remove this message entirely and @@ -1678,9 +1665,9 @@ void NodeManager::HandleTaskUnblocked(const WorkerID &worker_id, cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( cpu_resources); } else { - // In this case, we simply don't reacquire the CPU resources for the - // worker. The worker can keep running and when the task finishes, it will - // simply not have any CPU resources to release. + // In this case, we simply don't reacquire the CPU resources for the worker.
+ // The worker can keep running and when the task finishes, it will simply + // not have any CPU resources to release. RAY_LOG(WARNING) << "Resources oversubscribed: " << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] @@ -1692,7 +1679,7 @@ void NodeManager::HandleTaskUnblocked(const WorkerID &worker_id, } else { // The client is a driver. Drivers do not hold resources, so we simply // mark the driver as unblocked. - worker = worker_pool_.GetRegisteredDriver(worker_id); + worker = worker_pool_.GetRegisteredDriver(client); } // Unsubscribe from any `ray.get` objects that the task was blocked on. Any @@ -1727,24 +1714,6 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { task_dependency_manager_.TaskPending(task); } -void NodeManager::HandleWorkerAvailable(const WorkerID &worker_id) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); - RAY_CHECK(worker); - // If the worker was assigned a task, mark it as finished. - if (!worker->GetAssignedTaskId().IsNil()) { - FinishAssignedTask(*worker); - } - - // Return the worker to the idle pool. - worker_pool_.PushWorker(std::move(worker)); - // Local resource availability changed: invoke scheduling policy for local node. - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_[local_client_id].SetLoadResources( - local_queues_.GetResourceLoad()); - // Call task dispatch to assign work to the new worker. - DispatchTasks(local_queues_.GetReadyTasksWithResources()); -} - bool NodeManager::AssignTask(const Task &task) { const TaskSpecification &spec = task.GetTaskSpecification(); @@ -1764,12 +1733,12 @@ bool NodeManager::AssignTask(const Task &task) { if (worker == nullptr) { // There are no workers that can execute this task. // We couldn't assign this task, as no worker available. 
- RAY_LOG(DEBUG) << "No idle worker is found to assign task " << spec.TaskId(); return false; } - RAY_LOG(DEBUG) << "Assigning task " << spec.TaskId() << " to worker " - << worker->GetWorkerId() << " pid " << worker->Pid(); + RAY_LOG(DEBUG) << "Assigning task " << spec.TaskId() << " to worker with pid " + << worker->Pid(); + flatbuffers::FlatBufferBuilder fbb; // Resource accounting: acquire resources for the assigned task. auto acquired_resources = @@ -1786,44 +1755,40 @@ bool NodeManager::AssignTask(const Task &task) { worker->SetTaskResourceIds(acquired_resources); } + auto task_id = spec.TaskId(); + auto finish_assign_task_callback = [this, worker, task_id](Status status) { + if (worker->UsePush()) { + // NOTE: we cannot directly call `FinishAssignTask` here because + // it assumes the task is in SWAP queue, thus we need to delay invoking this + // function after the assigned tasks are moved from READY queue to SWAP queue + // in `DispatchTasks`. + // Another option is to move the tasks to SWAP queue here just before calling + // `FinishAssignTask` so we can save an io_service post, at the + // expense of calling `MoveTask` for each of the assigned tasks. + // TODO(zhijunfu): after all workers are fully migrated to push mode, the + // `post` below and swap queue can be removed. + io_service_.post([this, status, worker, task_id]() { + FinishAssignTask(task_id, *worker, status.ok()); + }); + } else { + FinishAssignTask(task_id, *worker, status.ok()); + } + }; + ResourceIdSet resource_id_set = worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds()); - worker->AssignTask(task, resource_id_set); - // Actor tasks require extra accounting to track the actor's state. - if (spec.IsActorTask()) { - auto actor_entry = actor_registry_.find(spec.ActorId()); - RAY_CHECK(actor_entry != actor_registry_.end()); - // Process any new actor handles that were created since the - // previous task on this handle was executed. 
The first task - // submitted on a new actor handle will depend on the dummy object - // returned by the previous task, so the dependency will not be - // released until this first task is submitted. - for (auto &new_handle_id : spec.NewActorHandles()) { - const auto prev_actor_task_id = spec.PreviousActorTaskDummyObjectId(); - RAY_CHECK(!prev_actor_task_id.IsNil()); - // Add the new handle and give it a reference to the finished task's - // execution dependency. - actor_entry->second.AddHandle(new_handle_id, prev_actor_task_id); - } - - // TODO(swang): For actors with multiple actor handles, to - // guarantee that tasks are replayed in the same order after a - // failure, we must update the task's execution dependency to be - // the actor's current execution dependency. - } - - // Notify the task dependency manager that we no longer need this task's - // object dependencies. - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + worker->AssignTask(task, resource_id_set, finish_assign_task_callback); + // We assigned this task to a worker. + // (Note this means that we sent the task to the worker. The assignment + // might still fail if the worker fails in the meantime, for instance.) return true; } void NodeManager::FinishAssignedTask(Worker &worker) { TaskID task_id = worker.GetAssignedTaskId(); - RAY_LOG(DEBUG) << "Finished task " << task_id << " from worker " << worker.GetWorkerId() - << " with pid " << worker.Pid(); + RAY_LOG(DEBUG) << "Finished task " << task_id; // (See design_docs/task_states.rst for the state transition diagram.) Task task; @@ -2237,9 +2202,8 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, retry_timer->expires_from_now(retry_duration); retry_timer->async_wait( [this, task_id, retry_timer](const boost::system::error_code &error) { - // Timer killing will receive the - // boost::asio::error::operation_aborted, we only handle the timeout - // event. 
+ // Timer killing will receive the boost::asio::error::operation_aborted, + // we only handle the timeout event. RAY_CHECK(!error); RAY_LOG(INFO) << "Resubmitting task " << task_id << " because ForwardTask failed."; @@ -2361,6 +2325,63 @@ void NodeManager::ForwardTask( }); } +void NodeManager::FinishAssignTask(const TaskID &task_id, Worker &worker, bool success) { + // Remove the ASSIGNED task from the SWAP queue. + Task assigned_task; + TaskState state; + if (!local_queues_.RemoveTask(task_id, &assigned_task, &state)) { + return; + } + + RAY_CHECK(state == TaskState::SWAP); + + if (success) { + auto spec = assigned_task.GetTaskSpecification(); + // We successfully assigned the task to the worker. + worker.AssignTaskId(spec.TaskId()); + worker.AssignJobId(spec.JobId()); + // Actor tasks require extra accounting to track the actor's state. + if (spec.IsActorTask()) { + auto actor_entry = actor_registry_.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry_.end()); + // Process any new actor handles that were created since the + // previous task on this handle was executed. The first task + // submitted on a new actor handle will depend on the dummy object + // returned by the previous task, so the dependency will not be + // released until this first task is submitted. + for (auto &new_handle_id : spec.NewActorHandles()) { + const auto prev_actor_task_id = spec.PreviousActorTaskDummyObjectId(); + RAY_CHECK(!prev_actor_task_id.IsNil()); + // Add the new handle and give it a reference to the finished task's + // execution dependency. + actor_entry->second.AddHandle(new_handle_id, prev_actor_task_id); + } + + // TODO(swang): For actors with multiple actor handles, to + // guarantee that tasks are replayed in the same order after a + // failure, we must update the task's execution dependency to be + // the actor's current execution dependency. + } + + // Mark the task as running. + // (See design_docs/task_states.rst for the state transition diagram.) 
+ local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); + // Notify the task dependency manager that we no longer need this task's + // object dependencies. + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); + } else { + RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; + // We failed to send the task to the worker, so disconnect the worker. + ProcessDisconnectClientMessage(worker.Connection()); + // Queue this task for future assignment. We need to do this since + // DispatchTasks() removed it from the ready queue. The task will be + // assigned to a worker once one becomes available. + // (See design_docs/task_states.rst for the state transition diagram.) + local_queues_.QueueTasks({assigned_task}, TaskState::READY); + DispatchTasks(MakeTasksWithResources({assigned_task})); + } +} + void NodeManager::DumpDebugState() const { std::fstream fs; fs.open(initial_config_.session_dir + "/debug_state.txt", diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index a6093eadd..c063d9182 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -7,8 +7,6 @@ #include "ray/rpc/client_call.h" #include "ray/rpc/node_manager/node_manager_server.h" #include "ray/rpc/node_manager/node_manager_client.h" -#include "ray/rpc/raylet/raylet_server.h" -#include "ray/object_manager/object_manager.h" #include "ray/common/task/task.h" #include "ray/common/client_connection.h" #include "ray/common/task/task_common.h" @@ -66,8 +64,7 @@ struct NodeManagerConfig { std::string session_dir; }; -class NodeManager : public rpc::NodeManagerServiceHandler, - public rpc::RayletServiceHandler { +class NodeManager : public rpc::NodeManagerServiceHandler { public: /// Create a node manager. /// @@ -78,6 +75,23 @@ class NodeManager : public rpc::NodeManagerServiceHandler, std::shared_ptr gcs_client, std::shared_ptr object_directory_); + /// Process a new client connection. 
+ /// + /// \param client The client to process. + /// \return Void. + void ProcessNewClient(LocalClientConnection &client); + + /// Process a message from a client. This method is responsible for + /// explicitly listening for more messages from the client if the client is + /// still alive. + /// + /// \param client The client that sent the message. + /// \param message_type The message type (e.g., a flatbuffer enum). + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessClientMessage(const std::shared_ptr &client, + int64_t message_type, const uint8_t *message_data); + /// Subscribe to the relevant GCS tables and set up handlers. /// /// \return Status indicating whether this was done successfully or not. @@ -91,91 +105,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// Record metrics. void RecordMetrics() const; - public: /// Get the port of the node manager rpc server. int GetServerPort() const { return node_manager_server_.GetPort(); } - /// Preprocess request from raylet client. We will check whether the worker is being - /// killed due to job finishing. - /// - /// \param worker_id The worker id. - /// \param request_name The request name. - /// \return False if there is no need to process this request. - bool PreprocessRequest(const WorkerID &worker_id, const std::string &request_name); - - /// Implementation of node manager grpc service. - - /// Handle a `ForwardTask` request. - /// - /// \param request The request. - /// \param reply The reply that will be sent to client. - /// \param send_reply_callback Invoke this callback to send reply asynchronously. - void HandleForwardTask(const rpc::ForwardTaskRequest &request, - rpc::ForwardTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - /// Implementation of raylet grpc service handlers, please see definitions - /// in `src/ray/protobuf/raylet.proto` for more details. 
- void HandleRegisterClientRequest(const rpc::RegisterClientRequest &request, - rpc::RegisterClientReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, - rpc::SubmitTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleDisconnectClientRequest(const rpc::DisconnectClientRequest &request, - rpc::DisconnectClientReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleGetTaskRequest(const rpc::GetTaskRequest &request, rpc::GetTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleTaskDoneRequest(const rpc::TaskDoneRequest &request, - rpc::TaskDoneReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleFetchOrReconstructRequest( - const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleNotifyUnblockedRequest(const rpc::NotifyUnblockedRequest &request, - rpc::NotifyUnblockedReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleWaitRequest(const rpc::WaitRequest &request, rpc::WaitReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandlePushErrorRequest(const rpc::PushErrorRequest &request, - rpc::PushErrorReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandlePushProfileEventsRequest( - const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleFreeObjectsInStoreRequest( - const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandlePrepareActorCheckpointRequest( - const rpc::PrepareActorCheckpointRequest &request, - rpc::PrepareActorCheckpointReply *reply, - rpc::SendReplyCallback 
send_reply_callback) override; - - void HandleNotifyActorResumedFromCheckpointRequest( - const rpc::NotifyActorResumedFromCheckpointRequest &request, - rpc::NotifyActorResumedFromCheckpointReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleSetResourceRequest(const rpc::SetResourceRequest &request, - rpc::SetResourceReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - - void HandleHeartbeatRequest(const rpc::HeartbeatRequest &request, - rpc::HeartbeatReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - private: /// Methods for handling clients. @@ -276,11 +208,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// \return Void. void SubmitTask(const Task &task, const Lineage &uncommitted_lineage, bool forwarded = false); - /// Handle the case that a worker is available. - /// - /// \param id Id of the worker. - /// \return Void. - void HandleWorkerAvailable(const WorkerID &worker_id); /// Assign a task. The task is assumed to not be queued in local_queues_. /// /// \param task The task in question. @@ -392,7 +319,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// \param ray_get Whether the task is blocked in a `ray.get` call, as /// opposed to a `ray.wait` call. /// \return Void. - void HandleTaskBlocked(const WorkerID &worker_id, + void HandleTaskBlocked(const std::shared_ptr &client, const std::vector &required_object_ids, const TaskID &current_task_id, bool ray_get); @@ -405,7 +332,8 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// \param client The client that is executing the unblocked task. /// \param current_task_id The task that is unblocked. /// \return Void. - void HandleTaskUnblocked(const WorkerID &worker_id, const TaskID &current_task_id); + void HandleTaskUnblocked(const std::shared_ptr &client, + const TaskID &current_task_id); /// Kill a worker.
/// @@ -481,8 +409,9 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// \param client The client that sent the message. /// \param intentional_disconnect Whether the client was intentionally disconnected. /// \return Void. - void ProcessDisconnectClientMessage(const WorkerID &worker_id, - bool intentional_disconnect = false); + void ProcessDisconnectClientMessage( + const std::shared_ptr &client, + bool intentional_disconnect = false); /// Process client message of SubmitTask /// @@ -550,6 +479,19 @@ class NodeManager : public rpc::NodeManagerServiceHandler, void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); + /// Finish assigning a task to a worker. + /// + /// \param task_id Id of the task. + /// \param worker Worker which the task is assigned to. + /// \param success Whether the task is successfully assigned to the worker. + /// \return void. + void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success); + + /// Handle a `ForwardTask` request. + void HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + // GCS client ID for this node. ClientID client_id_; boost::asio::io_service &io_service_; @@ -609,9 +551,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler, /// The node manager RPC service. rpc::NodeManagerGrpcService node_manager_service_; - /// The raylet RPC service. - rpc::RayletGrpcService raylet_service_; - /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s /// as well as all `WorkerTaskClient`s. 
rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index bcca12411..56b509240 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -7,6 +7,34 @@ #include "ray/common/status.h" +namespace { + +const std::vector GenerateEnumNames(const char *const *enum_names_ptr, + int start_index, int end_index) { + std::vector enum_names; + for (int i = 0; i < start_index; ++i) { + enum_names.push_back("EmptyMessageType"); + } + size_t i = 0; + while (true) { + const char *name = enum_names_ptr[i]; + if (name == nullptr) { + break; + } + enum_names.push_back(name); + i++; + } + RAY_CHECK(static_cast(end_index) == enum_names.size() - 1) + << "Message Type mismatch!"; + return enum_names; +} + +static const std::vector node_manager_message_enum = + GenerateEnumNames(ray::protocol::EnumNamesMessageType(), + static_cast(ray::protocol::MessageType::MIN), + static_cast(ray::protocol::MessageType::MAX)); +} // namespace + namespace ray { namespace raylet { @@ -23,10 +51,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ node_manager_(main_service, node_manager_config, object_manager_, gcs_client_, object_directory_), socket_name_(socket_name), - raylet_server_("Raylet", socket_name), - raylet_service_(main_service, node_manager_) { - raylet_server_.RegisterService(raylet_service_); - raylet_server_.Run(); + acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)), + socket_(main_service) { + // Start listening for clients. 
+ DoAccept(); RAY_CHECK_OK(RegisterGcs( node_ip_address, socket_name_, object_manager_config.store_socket_name, @@ -80,6 +108,31 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, return Status::OK(); } +void Raylet::DoAccept() { + acceptor_.async_accept(socket_, boost::bind(&Raylet::HandleAccept, this, + boost::asio::placeholders::error)); +} + +void Raylet::HandleAccept(const boost::system::error_code &error) { + if (!error) { + // TODO: typedef these handlers. + ClientHandler client_handler = + [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; + MessageHandler message_handler = + [this](std::shared_ptr client, int64_t message_type, + const uint8_t *message) { + node_manager_.ProcessClientMessage(client, message_type, message); + }; + // Accept a new local client and dispatch it to the node manager. + auto new_connection = LocalClientConnection::Create( + client_handler, message_handler, std::move(socket_), "worker", + node_manager_message_enum, + static_cast(protocol::MessageType::DisconnectClient)); + } + // We're ready to accept another client. + DoAccept(); +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 7962d5114..ec7fe74cf 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -75,12 +75,10 @@ class Raylet { /// The name of the socket this raylet listens on. std::string socket_name_; - /// The gRPC server, listens for local raylet client connections through a unix domain - /// socket. - rpc::GrpcServer raylet_server_; - - /// The gRPC service. - rpc::RayletGrpcService raylet_service_; + /// An acceptor for new clients. + boost::asio::local::stream_protocol::acceptor acceptor_; + /// The socket to listen on for new clients. 
+ boost::asio::local::stream_protocol::socket socket_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc new file mode 100644 index 000000000..1c8871bf0 --- /dev/null +++ b/src/ray/raylet/raylet_client.cc @@ -0,0 +1,392 @@ +#include "raylet_client.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ray/common/common_protocol.h" +#include "ray/common/ray_config.h" +#include "ray/common/task/task_spec.h" +#include "ray/raylet/format/node_manager_generated.h" +#include "ray/util/logging.h" + +using MessageType = ray::protocol::MessageType; + +// TODO(rkn): The io methods below should be removed. +int connect_ipc_sock(const std::string &socket_pathname) { + struct sockaddr_un socket_address; + int socket_fd; + + socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (socket_fd < 0) { + RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; + return -1; + } + + memset(&socket_address, 0, sizeof(socket_address)); + socket_address.sun_family = AF_UNIX; + if (socket_pathname.length() + 1 > sizeof(socket_address.sun_path)) { + RAY_LOG(ERROR) << "Socket pathname is too long."; + close(socket_fd); + return -1; + } + strncpy(socket_address.sun_path, socket_pathname.c_str(), socket_pathname.length() + 1); + + if (connect(socket_fd, (struct sockaddr *)&socket_address, sizeof(socket_address)) != + 0) { + close(socket_fd); + return -1; + } + return socket_fd; +} + +int read_bytes(int socket_fd, uint8_t *cursor, size_t length) { + ssize_t nbytes = 0; + // Termination condition: EOF or read 'length' bytes total. + size_t bytesleft = length; + size_t offset = 0; + while (bytesleft > 0) { + nbytes = read(socket_fd, cursor + offset, bytesleft); + if (nbytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + continue; + } + return -1; // Errno will be set. + } else if (0 == nbytes) { + // Encountered early EOF. 
+ return -1; + } + RAY_CHECK(nbytes > 0); + bytesleft -= nbytes; + offset += nbytes; + } + return 0; +} + +int write_bytes(int socket_fd, uint8_t *cursor, size_t length) { + ssize_t nbytes = 0; + size_t bytesleft = length; + size_t offset = 0; + while (bytesleft > 0) { + // While we haven't written the whole message, write to the file + // descriptor, advance the cursor, and decrease the amount left to write. + nbytes = write(socket_fd, cursor + offset, bytesleft); + if (nbytes < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { + continue; + } + return -1; // Errno will be set. + } else if (0 == nbytes) { + // Encountered early EOF. + return -1; + } + RAY_CHECK(nbytes > 0); + bytesleft -= nbytes; + offset += nbytes; + } + return 0; +} + +RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries, + int64_t timeout) { + // Pick the default values if the user did not specify. + if (num_retries < 0) { + num_retries = RayConfig::instance().num_connect_attempts(); + } + if (timeout < 0) { + timeout = RayConfig::instance().connect_timeout_milliseconds(); + } + RAY_CHECK(!raylet_socket.empty()); + conn_ = -1; + for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { + conn_ = connect_ipc_sock(raylet_socket); + if (conn_ >= 0) break; + if (num_attempts > 0) { + RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << raylet_socket + << " (num_attempts = " << num_attempts + << ", num_retries = " << num_retries << ")"; + } + // Sleep for timeout milliseconds. + usleep(timeout * 1000); + } + // If we could not connect to the socket, exit. 
+ if (conn_ == -1) { + RAY_LOG(FATAL) << "Could not connect to socket " << raylet_socket; + } +} + +ray::Status RayletConnection::Disconnect() { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateDisconnectClient(fbb); + fbb.Finish(message); + auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb); + // Don't be too strict for disconnection errors. + // Just create logs and prevent it from crash. + if (!status.ok()) { + RAY_LOG(ERROR) << status.ToString() + << " [RayletClient] Failed to disconnect from raylet."; + } + return ray::Status::OK(); +} + +ray::Status RayletConnection::ReadMessage(MessageType type, + std::unique_ptr &message) { + int64_t cookie; + int64_t type_field; + int64_t length; + int closed = read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); + if (closed) goto disconnected; + RAY_CHECK(cookie == RayConfig::instance().ray_cookie()); + closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); + if (closed) goto disconnected; + closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length)); + if (closed) goto disconnected; + message = std::unique_ptr(new uint8_t[length]); + closed = read_bytes(conn_, message.get(), length); + if (closed) { + // Handle the case in which the socket is closed. + message.reset(nullptr); + disconnected: + message = nullptr; + type_field = static_cast(MessageType::DisconnectClient); + length = 0; + } + if (type_field == static_cast(MessageType::DisconnectClient)) { + return ray::Status::IOError("[RayletClient] Raylet connection closed."); + } + if (type_field != static_cast(type)) { + return ray::Status::TypeError( + std::string("[RayletClient] Raylet connection corrupted. ") + + "Expected message type: " + std::to_string(static_cast(type)) + + "; got message type: " + std::to_string(type_field) + + ". 
Check logs or dmesg for previous errors."); + } + return ray::Status::OK(); +} + +ray::Status RayletConnection::WriteMessage(MessageType type, + flatbuffers::FlatBufferBuilder *fbb) { + std::unique_lock guard(write_mutex_); + int64_t cookie = RayConfig::instance().ray_cookie(); + int64_t length = fbb ? fbb->GetSize() : 0; + uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr; + int64_t type_field = static_cast(type); + auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly."); + int closed; + closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); + if (closed) return io_error; + closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); + if (closed) return io_error; + closed = write_bytes(conn_, (uint8_t *)&length, sizeof(length)); + if (closed) return io_error; + closed = write_bytes(conn_, bytes, length * sizeof(char)); + if (closed) return io_error; + return ray::Status::OK(); +} + +ray::Status RayletConnection::AtomicRequestReply( + MessageType request_type, MessageType reply_type, + std::unique_ptr &reply_message, flatbuffers::FlatBufferBuilder *fbb) { + std::unique_lock guard(mutex_); + auto status = WriteMessage(request_type, fbb); + if (!status.ok()) return status; + return ReadMessage(reply_type, reply_message); +} + +RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, + bool is_worker, const JobID &job_id, const Language &language, + int port) + : worker_id_(worker_id), is_worker_(is_worker), job_id_(job_id), language_(language) { + // For C++14, we could use std::make_unique + conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); + + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateRegisterClientRequest( + fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id), + language, port); + fbb.Finish(message); + // Register the process ID with the raylet. 
+ // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. + auto status = conn_->WriteMessage(MessageType::RegisterClientRequest, &fbb); + RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet."); +} + +ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateSubmitTaskRequest( + fbb, fbb.CreateString(task_spec.Serialize())); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::SubmitTask, &fbb); +} + +ray::Status RayletClient::GetTask(std::unique_ptr *task_spec) { + std::unique_ptr reply; + // Receive a task from the raylet. This will block until the raylet + // gives this client a task. + auto status = + conn_->AtomicRequestReply(MessageType::GetTask, MessageType::ExecuteTask, reply); + if (!status.ok()) return status; + // Parse the flatbuffer object. + auto reply_message = flatbuffers::GetRoot(reply.get()); + // Set the resource IDs for this task. 
+ resource_ids_.clear(); + for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); ++i) { + auto const &fractional_resource_ids = + reply_message->fractional_resource_ids()->Get(i); + auto &acquired_resources = + resource_ids_[string_from_flatbuf(*fractional_resource_ids->resource_name())]; + + size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); + size_t num_resource_fractions = fractional_resource_ids->resource_fractions()->size(); + RAY_CHECK(num_resource_ids == num_resource_fractions); + RAY_CHECK(num_resource_ids > 0); + for (size_t j = 0; j < num_resource_ids; ++j) { + int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); + double resource_fraction = fractional_resource_ids->resource_fractions()->Get(j); + if (num_resource_ids > 1) { + int64_t whole_fraction = resource_fraction; + RAY_CHECK(whole_fraction == resource_fraction); + } + acquired_resources.push_back(std::make_pair(resource_id, resource_fraction)); + } + } + + // Return the copy of the task spec and pass ownership to the caller. 
+ task_spec->reset( + new ray::TaskSpecification(string_from_flatbuf(*reply_message->task_spec()))); + return ray::Status::OK(); +} + +ray::Status RayletClient::TaskDone() { + return conn_->WriteMessage(MessageType::TaskDone); +} + +ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, + bool fetch_only, + const TaskID ¤t_task_id) { + flatbuffers::FlatBufferBuilder fbb; + auto object_ids_message = to_flatbuf(fbb, object_ids); + auto message = ray::protocol::CreateFetchOrReconstruct( + fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); + return status; +} + +ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); +} + +ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, + int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id, WaitResultPair *result) { + // Write request. + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateWaitRequest( + fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, + to_flatbuf(fbb, current_task_id)); + fbb.Finish(message); + std::unique_ptr reply; + auto status = conn_->AtomicRequestReply(MessageType::WaitRequest, + MessageType::WaitReply, reply, &fbb); + if (!status.ok()) return status; + // Parse the flatbuffer object. 
+ auto reply_message = flatbuffers::GetRoot(reply.get()); + auto found = reply_message->found(); + for (uint i = 0; i < found->size(); i++) { + ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str()); + result->first.push_back(object_id); + } + auto remaining = reply_message->remaining(); + for (uint i = 0; i < remaining->size(); i++) { + ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str()); + result->second.push_back(object_id); + } + return ray::Status::OK(); +} + +ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreatePushErrorRequest( + fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), + fbb.CreateString(error_message), timestamp); + fbb.Finish(message); + + return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); +} + +ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { + flatbuffers::FlatBufferBuilder fbb; + auto message = fbb.CreateString(profile_events.SerializeAsString()); + fbb.Finish(message); + + auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); + // Don't be too strict for profile errors. Just create logs and prevent it from crash. 
+ if (!status.ok()) { + RAY_LOG(ERROR) << status.ToString() + << " [RayletClient] Failed to push profile events."; + } + return ray::Status::OK(); +} + +ray::Status RayletClient::FreeObjects(const std::vector &object_ids, + bool local_only, bool delete_creating_tasks) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateFreeObjectsRequest( + fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids)); + fbb.Finish(message); + + auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); + return status; +} + +ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, + ActorCheckpointID &checkpoint_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = + ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); + fbb.Finish(message); + + std::unique_ptr reply; + auto status = + conn_->AtomicRequestReply(MessageType::PrepareActorCheckpointRequest, + MessageType::PrepareActorCheckpointReply, reply, &fbb); + if (!status.ok()) return status; + auto reply_message = + flatbuffers::GetRoot(reply.get()); + checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str()); + return ray::Status::OK(); +} + +ray::Status RayletClient::NotifyActorResumedFromCheckpoint( + const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint( + fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id)); + fbb.Finish(message); + + return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); +} + +ray::Status RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ray::ClientID &client_Id) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateSetResourceRequest( + fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); + fbb.Finish(message); + 
return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); +} diff --git a/src/ray/rpc/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h similarity index 70% rename from src/ray/rpc/raylet/raylet_client.h rename to src/ray/raylet/raylet_client.h index 6960b0f8e..235ba9cfb 100644 --- a/src/ray/rpc/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -1,26 +1,18 @@ -#ifndef RAY_RPC_RAYLET_CLIENT_H -#define RAY_RPC_RAYLET_CLIENT_H +#ifndef RAYLET_CLIENT_H +#define RAYLET_CLIENT_H -#include #include #include -#include #include -#include #include #include -#include - #include "ray/common/status.h" #include "ray/common/task/task_spec.h" -#include "src/ray/common/status.h" -#include "src/ray/protobuf/raylet.grpc.pb.h" -#include "src/ray/protobuf/raylet.pb.h" -#include "src/ray/rpc/client_call.h" using ray::ActorCheckpointID; using ray::ActorID; +using ray::ClientID; using ray::JobID; using ray::ObjectID; using ray::TaskID; @@ -28,39 +20,65 @@ using ray::WorkerID; using ray::Language; using ray::rpc::ProfileTableData; -using WaitResultPair = std::pair, std::vector>; - -namespace ray { -namespace rpc { +using MessageType = ray::protocol::MessageType; using ResourceMappingType = std::unordered_map>>; +using WaitResultPair = std::pair, std::vector>; + +class RayletConnection { + public: + /// Connect to the raylet. + /// + /// \param raylet_socket The name of the socket to use to connect to the raylet. + /// \param worker_id A unique ID to represent the worker. + /// \param is_worker Whether this client is a worker. If it is a worker, an + /// additional message will be sent to register as one. + /// \param job_id The ID of the driver. This is non-nil if the client is a + /// driver. + /// \return The connection information. + RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout); + + ~RayletConnection() { close(conn_); } + /// Notify the raylet that this client is disconnecting gracefully. 
This + /// is used by actors to exit gracefully so that the raylet doesn't + /// propagate an error message to the driver. + /// + /// \return ray::Status. + ray::Status Disconnect(); + ray::Status ReadMessage(MessageType type, std::unique_ptr &message); + ray::Status WriteMessage(MessageType type, + flatbuffers::FlatBufferBuilder *fbb = nullptr); + ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type, + std::unique_ptr &reply_message, + flatbuffers::FlatBufferBuilder *fbb = nullptr); + + private: + /// File descriptor of the Unix domain socket that connects to raylet. + int conn_; + /// A mutex to protect stateful operations of the raylet client. + std::mutex mutex_; + /// A mutex to protect write operations of the raylet client. + std::mutex write_mutex_; +}; -/// Client used for communicating with the raylet. class RayletClient { public: - /// Constructor for the raylet client. - /// TODO(jzh): At present, client call manager and reply handler service are generated - /// in raylet client. Instead, we should add parameters to the constructor and pass them - /// in. Change them as input parameters once we changed the worker into a server. + /// Connect to the raylet. /// - /// \param[in] raylet_socket Unix domain socket of the raylet server. - /// \param[in] worker_id The worker id. - /// \param[in] is_worker Indicates whether a worker or a driver. - /// \param[in] job_id The job id that this raylet client belongs to. - /// \param[in] language The language type, python or java. - /// \param[in] port The listening port of the worker server, -1 means that the worker - /// does not have a server. + /// \param raylet_socket The name of the socket to use to connect to the raylet. + /// \param worker_id A unique ID to represent the worker. + /// \param is_worker Whether this client is a worker. If it is a worker, an + /// additional message will be sent to register as one. + /// \param job_id The ID of the driver. 
This is non-nil if the client is a driver. + /// \return The connection information. RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, int port = -1); - ~RayletClient(); + ray::Status Disconnect() { return conn_->Disconnect(); }; - /// Send disconnect request to local raylet. - ray::Status Disconnect(); - - /// Submit a task to the local raylet. + /// Submit a task using the raylet code path. /// /// \param The task specification. /// \return ray::Status. @@ -160,7 +178,7 @@ class RayletClient { Language GetLanguage() const { return language_; } - WorkerID GetWorkerId() const { return worker_id_; } + WorkerID GetWorkerID() const { return worker_id_; } JobID GetJobID() const { return job_id_; } @@ -169,53 +187,16 @@ class RayletClient { const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } private: - /// Try to register client in raylet, we would retry serveral time to - /// reconnect if failed. We need this because raylet client may start before raylet - /// server. - /// - /// \param times Number of times to retry. - void TryRegisterClient(int retry_times); - - ray::Status RegisterClient(); - - /// Send heartbeat requests to the raylet server. - void Heartbeat(); - /// Id of the worker to which this raylet client belongs. const WorkerID worker_id_; - /// Indicates whether this worker is a driver worker. - /// Driver is treated as a special worker. const bool is_worker_; const JobID job_id_; const Language language_; - const int port_; /// A map from resource name to the resource IDs that are currently reserved /// for this worker. Each pair consists of the resource ID and the fraction /// of that resource allocated for this worker. ResourceMappingType resource_ids_; - - /// The gRPC-generated stub. - std::unique_ptr stub_; - - /// Service for handling reply. - boost::asio::io_service main_service_; - - /// Asio work for main service. 
- boost::asio::io_service::work work_; - - /// The `ClientCallManager` used for managing requests. - ClientCallManager client_call_manager_; - - /// The thread used to handle reply. - std::thread rpc_thread_; - - /// Heartbeat timer. - boost::asio::deadline_timer heartbeat_timer_; - - /// Indicates whether the connection has been closed. - bool is_connected_; + /// The connection to the raylet server. + std::unique_ptr conn_; }; -} // namespace rpc -} // namespace ray - -#endif // RAY_RPC_RAYLET_CLIENT_H +#endif diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index ab109ccbf..2052868b3 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -5,6 +5,7 @@ #include +#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/reconstruction_policy.h" #include "ray/object_manager/object_directory.h" @@ -417,7 +418,6 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { auto task_reconstruction_data = std::make_shared(); task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); task_reconstruction_data->set_num_reconstructions(0); - RAY_CHECK_OK( mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index dc4e5d4ab..7effa44ed 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -3,7 +3,6 @@ // clang-format off #include "ray/common/id.h" -#include "ray/common/ray_config.h" #include "ray/common/task/task.h" #include "ray/object_manager/object_manager.h" #include "ray/raylet/reconstruction_policy.h" diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 49624d4bd..52814bb70 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -2,6 +2,7 @@ #include +#include 
"ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" namespace ray { @@ -10,27 +11,25 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - rpc::ClientCallManager &client_call_manager, bool is_worker) + std::shared_ptr connection, + rpc::ClientCallManager &client_call_manager) : worker_id_(worker_id), pid_(pid), - port_(port), language_(language), + port_(port), + connection_(connection), + dead_(false), blocked_(false), - num_missed_heartbeats_(0), - is_being_killed_(false), - client_call_manager_(client_call_manager), - is_worker_(is_worker) { + client_call_manager_(client_call_manager) { if (port_ > 0) { rpc_client_ = std::unique_ptr( new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_)); } } -void Worker::MarkAsBeingKilled() { is_being_killed_ = true; } +void Worker::MarkDead() { dead_ = true; } -bool Worker::IsBeingKilled() const { return is_being_killed_; } - -bool Worker::IsWorker() const { return is_worker_; } +bool Worker::IsDead() const { return dead_; } void Worker::MarkBlocked() { blocked_ = true; } @@ -44,8 +43,6 @@ pid_t Worker::Pid() const { return pid_; } Language Worker::GetLanguage() const { return language_; } -const WorkerID &Worker::GetWorkerId() const { return worker_id_; } - int Worker::Port() const { return port_; } void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } @@ -79,6 +76,10 @@ void Worker::AssignActorId(const ActorID &actor_id) { const ActorID &Worker::GetActorId() const { return actor_id_; } +const std::shared_ptr Worker::Connection() const { + return connection_; +} + const ResourceIdSet &Worker::GetLifetimeResourceIds() const { return lifetime_resource_ids_; } @@ -112,16 +113,10 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) { task_resource_ids_.Release(cpu_resources); } -void 
Worker::SetGetTaskReplyAndCallback( - rpc::GetTaskReply *reply, const rpc::SendReplyCallback &&send_reply_callback) { - RAY_CHECK(reply_ == nullptr && send_reply_callback_ == nullptr); - reply_ = reply; - send_reply_callback_ = std::move(send_reply_callback); -} - bool Worker::UsePush() const { return rpc_client_ != nullptr; } -void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) { +void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, + const std::function finish_assign_callback) { const TaskSpecification &spec = task.GetTaskSpecification(); if (rpc_client_ != nullptr) { // Use push mode. @@ -131,16 +126,15 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) task.GetTaskSpecification().GetMessage()); request.mutable_task()->mutable_task_execution_spec()->CopyFrom( task.GetTaskExecutionSpec().GetMessage()); - for (const auto &e : resource_id_set.ToProtobuf()) { - auto resource = request.add_resource_ids(); - *resource = e; - } + request.set_resource_ids(resource_id_set.Serialize()); + auto status = rpc_client_->AssignTask( request, [](Status status, const rpc::AssignTaskReply &reply) { // Worker has finished this task. There's nothing to do here // and assigning new task will be done when raylet receives // `TaskDone` message. }); + finish_assign_callback(status); if (!status.ok()) { RAY_LOG(ERROR) << "Failed to assign task " << task.GetTaskSpecification().TaskId() << " to worker " << worker_id_; @@ -151,19 +145,17 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) } else { // Use pull mode. This corresponds to existing python/java workers that haven't been // migrated to core worker architecture. 
- RAY_CHECK(reply_ != nullptr && send_reply_callback_ != nullptr); - reply_->set_task_spec(task.GetTaskSpecification().Serialize()); - for (const auto &e : resource_id_set.ToProtobuf()) { - auto resource = reply_->add_fractional_resource_ids(); - *resource = e; - } - send_reply_callback_(Status::OK(), nullptr, nullptr); - reply_ = nullptr; - send_reply_callback_ = nullptr; + flatbuffers::FlatBufferBuilder fbb; + auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb); + + auto message = + protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()), + fbb.CreateVector(resource_id_set_flatbuf)); + fbb.Finish(message); + Connection()->WriteMessageAsync( + static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), + fbb.GetBufferPointer(), finish_assign_callback); } - // The status will be cleared when the worker dies. - AssignTaskId(spec.TaskId()); - AssignJobId(spec.JobId()); } } // namespace raylet diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 3d77da15b..93531ac8a 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -3,15 +3,12 @@ #include +#include "ray/common/client_connection.h" #include "ray/common/id.h" #include "ray/common/task/scheduling_resources.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" -#include "ray/protobuf/common.pb.h" #include "ray/rpc/worker/worker_client.h" -#include "src/ray/protobuf/gcs.pb.h" -#include "src/ray/protobuf/raylet.pb.h" -#include "src/ray/rpc/server_call.h" namespace ray { @@ -24,12 +21,12 @@ class Worker { public: /// A constructor that initializes a worker object. Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - rpc::ClientCallManager &client_call_manager, bool is_worker = true); + std::shared_ptr connection, + rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. 
~Worker() {} - void MarkAsBeingKilled(); - bool IsBeingKilled() const; - bool IsWorker() const; + void MarkDead(); + bool IsDead() const; void MarkBlocked(); void MarkUnblocked(); bool IsBlocked() const; @@ -38,7 +35,6 @@ class Worker { /// Return the worker's PID. pid_t Pid() const; Language GetLanguage() const; - const WorkerID &GetWorkerId() const; int Port() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; @@ -49,6 +45,7 @@ class Worker { const JobID &GetAssignedJobId() const; void AssignActorId(const ActorID &actor_id); const ActorID &GetActorId() const; + const std::shared_ptr Connection() const; const ResourceIdSet &GetLifetimeResourceIds() const; void SetLifetimeResourceIds(ResourceIdSet &resource_ids); @@ -60,34 +57,30 @@ class Worker { ResourceIdSet ReleaseTaskCpuResources(); void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources); - int TickHeartbeatTimer() { return ++num_missed_heartbeats_; } - void ClearHeartbeat() { num_missed_heartbeats_ = 0; } - - /// When receiving a `GetTask` request from worker, this function should be called to - /// pass the reply and callback of the request to the worker. Later, when we actually - /// assign a task to the worker, the reply will be filled and the callback will be - /// called. - void SetGetTaskReplyAndCallback(rpc::GetTaskReply *reply, - const rpc::SendReplyCallback &&send_reply_callback); - bool UsePush() const; - void AssignTask(const Task &task, const ResourceIdSet &resource_id_set); + void AssignTask(const Task &task, const ResourceIdSet &resource_id_set, + const std::function finish_assign_callback); private: /// The worker's ID. WorkerID worker_id_; /// The worker's PID. pid_t pid_; - /// The worker port. - int port_; /// The language type of this worker. Language language_; + /// Port that this worker listens on. + /// If port <= 0, this indicates that the worker will not listen to a port. + int port_; + /// Connection state of a worker. 
+ std::shared_ptr connection_; /// The worker's currently assigned task. TaskID assigned_task_id_; /// Job ID for the worker's current assigned task. JobID assigned_job_id_; /// The worker's actor ID. If this is nil, then the worker is not an actor. ActorID actor_id_; + /// Whether the worker is dead. + bool dead_; /// Whether the worker is blocked. Workers become blocked in a `ray.get`, if /// they require a data dependency while executing a task. bool blocked_; @@ -98,22 +91,11 @@ class Worker { // of a task. ResourceIdSet task_resource_ids_; std::unordered_set blocked_task_ids_; - /// How many heartbeats have been missed for this worker. - int num_missed_heartbeats_; - /// Indicates we have sent kill signal to the worker if it's true. We cannot treat the - /// worker process as really dead until we lost the heartbeats from the worker. - bool is_being_killed_; /// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all /// workers. rpc::ClientCallManager &client_call_manager_; - /// Indicates whether this is a worker or a driver. - bool is_worker_; /// The rpc client to send tasks to this worker. std::unique_ptr rpc_client_; - /// Reply of the `GetTask` request. - rpc::GetTaskReply *reply_ = nullptr; - /// Callback of the `GetTask` request. - rpc::SendReplyCallback send_reply_callback_ = nullptr; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 4d49d31dc..f2e77c6cf 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -12,6 +12,29 @@ #include "ray/util/logging.h" #include "ray/util/util.h" +namespace { + +// A helper function to get a worker from a list. 
+std::shared_ptr GetWorker( + const std::unordered_set> &worker_pool, + const std::shared_ptr &connection) { + for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) { + if ((*it)->Connection() == connection) { + return (*it); + } + } + return nullptr; +} + +// A helper function to remove a worker from a list. Returns true if the worker +// was found and removed. +bool RemoveWorker(std::unordered_set> &worker_pool, + const std::shared_ptr &worker) { + return worker_pool.erase(worker) > 0; +} + +} // namespace + namespace ray { namespace raylet { @@ -50,8 +73,8 @@ WorkerPool::~WorkerPool() { for (const auto &entry : states_by_lang_) { // Kill all registered workers. NOTE(swang): This assumes that the registered // workers were started by the pool. - for (const auto &worker_pair : entry.second.registered_workers) { - pids_to_kill.insert(worker_pair.second->Pid()); + for (const auto &worker : entry.second.registered_workers) { + pids_to_kill.insert(worker->Pid()); } // Kill all the workers that have been started but not registered. 
for (const auto &starting_worker : entry.second.starting_worker_processes) { @@ -166,8 +189,7 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_ar return 0; } -Status WorkerPool::RegisterWorker(const WorkerID &worker_id, - const std::shared_ptr &worker) { +Status WorkerPool::RegisterWorker(const std::shared_ptr &worker) { const auto pid = worker->Pid(); const auto port = worker->Port(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port; @@ -183,35 +205,34 @@ Status WorkerPool::RegisterWorker(const WorkerID &worker_id, state.starting_worker_processes.erase(it); } - state.registered_workers.emplace(worker_id, std::move(worker)); + state.registered_workers.emplace(std::move(worker)); return Status::OK(); } -Status WorkerPool::RegisterDriver(const WorkerID &driver_id, - const std::shared_ptr &driver) { +Status WorkerPool::RegisterDriver(const std::shared_ptr &driver) { RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); auto &state = GetStateForLanguage(driver->GetLanguage()); - state.registered_drivers.emplace(driver_id, std::move(driver)); + state.registered_drivers.insert(std::move(driver)); return Status::OK(); } -std::shared_ptr WorkerPool::GetRegisteredWorker(const WorkerID &worker_id) const { +std::shared_ptr WorkerPool::GetRegisteredWorker( + const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { - auto ®istered_workers = entry.second.registered_workers; - auto it = registered_workers.find(worker_id); - if (it != registered_workers.end()) { - return it->second; + auto worker = GetWorker(entry.second.registered_workers, connection); + if (worker != nullptr) { + return worker; } } return nullptr; } -std::shared_ptr WorkerPool::GetRegisteredDriver(const WorkerID &worker_id) const { +std::shared_ptr WorkerPool::GetRegisteredDriver( + const std::shared_ptr &connection) const { for (const auto &entry : states_by_lang_) { - auto ®istered_drivers = entry.second.registered_drivers; - auto it = 
registered_drivers.find(worker_id); - if (it != registered_drivers.end()) { - return it->second; + auto driver = GetWorker(entry.second.registered_drivers, connection); + if (driver != nullptr) { + return driver; } } return nullptr; @@ -295,20 +316,18 @@ std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { auto &state = GetStateForLanguage(worker->GetLanguage()); - RAY_CHECK(state.registered_workers.erase(worker->GetWorkerId())); + RAY_CHECK(RemoveWorker(state.registered_workers, worker)); stats::CurrentWorker().Record( 0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); - // Indicates that we disconnect a idle worker successfully. - return (state.idle.erase(worker) > 0); + return RemoveWorker(state.idle, worker); } void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { auto &state = GetStateForLanguage(driver->GetLanguage()); - RAY_CHECK(state.registered_drivers.erase(driver->GetWorkerId())); - + RAY_CHECK(RemoveWorker(state.registered_drivers, driver)); stats::CurrentDriver().Record( 0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); @@ -325,8 +344,7 @@ std::vector> WorkerPool::GetWorkersRunningTasksForJob( std::vector> workers; for (const auto &entry : states_by_lang_) { - for (const auto &worker_pair : entry.second.registered_workers) { - auto &worker = worker_pair.second; + for (const auto &worker : entry.second.registered_workers) { if (worker->GetAssignedJobId() == job_id) { workers.push_back(worker); } @@ -381,16 +399,14 @@ std::string WorkerPool::DebugString() const { void WorkerPool::RecordMetrics() const { for (const auto &entry : states_by_lang_) { // Record worker. 
- for (auto worker_pair : entry.second.registered_workers) { - auto &worker = worker_pair.second; + for (auto worker : entry.second.registered_workers) { stats::CurrentWorker().Record( worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); } // Record driver. - for (auto driver_pair : entry.second.registered_drivers) { - auto &driver = driver_pair.second; + for (auto driver : entry.second.registered_drivers) { stats::CurrentDriver().Record( driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); @@ -398,38 +414,6 @@ void WorkerPool::RecordMetrics() const { } } -void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats, - std::vector> *dead_workers) { - for (const auto &entry : states_by_lang_) { - // Worker heartbeat. - for (const auto &worker_pair : entry.second.registered_workers) { - auto &worker = worker_pair.second; - if (worker->TickHeartbeatTimer() >= max_missed_heartbeats) { - dead_workers->emplace_back(worker); - } - } - - // Driver heartbeat. - for (const auto &driver_pair : entry.second.registered_drivers) { - auto &driver = driver_pair.second; - if (driver->TickHeartbeatTimer() >= max_missed_heartbeats) { - dead_workers->emplace_back(driver); - } - } - } -} - -std::shared_ptr WorkerPool::GetWorker(const WorkerID &worker_id) { - auto worker = GetRegisteredWorker(worker_id); - if (!worker) { - worker = GetRegisteredDriver(worker_id); - if (!worker) { - return nullptr; - } - } - return worker; -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 44d8c5714..4d6f4b307 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -52,28 +52,29 @@ class WorkerPool { /// /// \param The Worker to be registered. /// \return If the registration is successful. 
- Status RegisterWorker(const WorkerID &worker_id, const std::shared_ptr &worker); + Status RegisterWorker(const std::shared_ptr &worker); /// Register a new driver. - /// Driver is a treated as a special worker, so use WorkerID as key here. /// /// \param The driver to be registered. /// \return If the registration is successful. - Status RegisterDriver(const WorkerID &driver_id, const std::shared_ptr &worker); + Status RegisterDriver(const std::shared_ptr &worker); /// Get the client connection's registered worker. /// /// \param The client connection owned by a registered worker. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. - std::shared_ptr GetRegisteredWorker(const WorkerID &worker_id) const; + std::shared_ptr GetRegisteredWorker( + const std::shared_ptr &connection) const; /// Get the client connection's registered driver. /// /// \param The client connection owned by a registered driver. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a driver. - std::shared_ptr GetRegisteredDriver(const WorkerID &driver_id) const; + std::shared_ptr GetRegisteredDriver( + const std::shared_ptr &connection) const; /// Disconnect a registered worker. /// @@ -130,21 +131,6 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; - /// Tick the heartbeat timer and get the workers that have timed out. - /// A worker which has missed `max_missed_heartbeats` times would be treated as a - /// dead process or the network to it has been down. - /// - /// \param[in] max_missed_heartbeats The maximum number of heartbeats that can be - /// missed before a worker times out. - /// \param[out] dead_workers Workers whose processes have been dead. - void TickHeartbeatTimer(int max_missed_heartbeats, - std::vector> *dead_workers); - - /// Return the pointer to the worker according to the worker id. 
- /// - /// \param worker_id The worker id. - std::shared_ptr GetWorker(const WorkerID &worker_id); - protected: /// Asynchronously start a new worker process. Once the worker process has /// registered with an external server, the process should create and @@ -182,9 +168,9 @@ class WorkerPool { std::unordered_map> idle_actor; /// All workers that have registered and are still connected, including both /// idle and executing. - std::unordered_map> registered_workers; + std::unordered_set> registered_workers; /// All drivers that have registered and are still connected. - std::unordered_map> registered_drivers; + std::unordered_set> registered_drivers; /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 3a4c4ad14..34ec1619a 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -71,9 +71,19 @@ class WorkerPoolTest : public ::testing::Test { std::shared_ptr CreateWorker(pid_t pid, const Language &language = Language::PYTHON) { - WorkerID worker_id = WorkerID::FromRandom(); - return std::shared_ptr(new Worker( - worker_id, pid, language, /* listening port */ -1, client_call_manager_)); + std::function client_handler = + [this](LocalClientConnection &client) { HandleNewClient(client); }; + std::function, int64_t, const uint8_t *)> + message_handler = [this](std::shared_ptr client, + int64_t message_type, const uint8_t *message) { + HandleMessage(client, message_type, message); + }; + boost::asio::local::stream_protocol::socket socket(io_service_); + auto client = + LocalClientConnection::Create(client_handler, message_handler, std::move(socket), + "worker", {}, error_message_type_); + return std::shared_ptr(new Worker(WorkerID::FromRandom(), pid, language, -1, + client, client_call_manager_)); } void SetWorkerCommands(const WorkerCommandMap 
&worker_commands) { @@ -86,6 +96,10 @@ class WorkerPoolTest : public ::testing::Test { boost::asio::io_service io_service_; int64_t error_message_type_; rpc::ClientCallManager client_call_manager_; + + private: + void HandleNewClient(LocalClientConnection &){}; + void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; }; static inline TaskSpecification ExampleTaskSpec( @@ -117,23 +131,21 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { workers.push_back(CreateWorker(pid)); } for (const auto &worker : workers) { - WorkerID worker_id = worker->GetWorkerId(); // Check that there's still a starting worker process // before all workers have been registered ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1); // Check that we cannot lookup the worker before it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr); - RAY_CHECK_OK(worker_pool_.RegisterWorker(worker_id, worker)); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); + RAY_CHECK_OK(worker_pool_.RegisterWorker(worker)); // Check that we can lookup the worker after it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), worker); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker); } // Check that there's no starting worker process ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0); for (const auto &worker : workers) { - WorkerID worker_id = worker->GetWorkerId(); worker_pool_.DisconnectWorker(worker); // Check that we cannot lookup the worker after it's disconnected. 
- ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); } } diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index bac98a583..31466dc17 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -6,20 +6,14 @@ namespace ray { namespace rpc { void GrpcServer::Run() { - std::string server_address; - // Set unix domain socket or tcp address. - if (!unix_socket_path_.empty()) { - server_address = "unix://" + unix_socket_path_; - } else { - server_address = "0.0.0.0:" + std::to_string(port_); - } + std::string server_address("0.0.0.0:" + std::to_string(port_)); grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); // Register all the services to this server. if (services_.empty()) { - RAY_LOG(WARNING) << "No service found when start grpc server " << name_; + RAY_LOG(WARNING) << "No service is found when start grpc server " << name_; } for (auto &entry : services_) { builder.RegisterService(&entry.get()); @@ -29,11 +23,7 @@ void GrpcServer::Run() { cq_ = builder.AddCompletionQueue(); // Build and start server. server_ = builder.BuildAndStart(); - if (unix_socket_path_.empty()) { - // For a TCP-based server, the actual port is decided after `AddListeningPort`. - server_address = "0.0.0.0:" + std::to_string(port_); - } - RAY_LOG(INFO) << name_ << " server started, listening on " << server_address; + RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << "."; // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 727301329..b2e884445 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -32,16 +32,7 @@ class GrpcServer { /// \param[in] port The port to bind this server to. 
If it's 0, a random available port /// will be chosen. GrpcServer(std::string name, const uint32_t port) - : name_(std::move(name)), port_(port), is_closed_(false) {} - - /// Construct a gRPC server that listens on unix domain socket. - /// - /// \param[in] name Name of this server, used for logging and debugging purpose. - /// \param[in] unix_socket_path Unix domain socket full path. - GrpcServer(std::string name, const std::string &unix_socket_path) - : GrpcServer(std::move(name), 0) { - unix_socket_path_ = unix_socket_path; - } + : name_(std::move(name)), port_(port), is_closed_(true) {} /// Destruct this gRPC server. ~GrpcServer() { Shutdown(); } @@ -82,8 +73,6 @@ class GrpcServer { int port_; /// Indicates whether this server has been closed. bool is_closed_; - /// Unix domain socket path. - std::string unix_socket_path_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. std::vector> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 5d7c915bf..005c75db4 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -5,7 +5,6 @@ #include -#include "ray/common/grpc_util.h" #include "ray/common/status.h" #include "ray/rpc/client_call.h" #include "ray/util/logging.h" diff --git a/src/ray/rpc/object_manager/object_manager_client.h b/src/ray/rpc/object_manager/object_manager_client.h index 9c9921503..f37a081e6 100644 --- a/src/ray/rpc/object_manager/object_manager_client.h +++ b/src/ray/rpc/object_manager/object_manager_client.h @@ -5,7 +5,6 @@ #include -#include "ray/common/grpc_util.h" #include "ray/common/status.h" #include "ray/util/logging.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" diff --git a/src/ray/rpc/raylet/raylet_client.cc b/src/ray/rpc/raylet/raylet_client.cc deleted file mode 100644 index 
80636e858..000000000 --- a/src/ray/rpc/raylet/raylet_client.cc +++ /dev/null @@ -1,440 +0,0 @@ - -#include "src/ray/rpc/raylet/raylet_client.h" - -namespace ray { -namespace rpc { - -#define RETURN_IF_DISCONNECTED(connected) \ - do { \ - if (!(connected)) { \ - return Status::Invalid("Raylet connection is closed."); \ - } \ - } while (0) - -RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, - bool is_worker, const JobID &job_id, const Language &language, - int port) - : worker_id_(worker_id), - is_worker_(is_worker), - job_id_(job_id), - language_(language), - port_(port), - main_service_(), - work_(main_service_), - client_call_manager_(main_service_), - heartbeat_timer_(main_service_), - is_connected_(false) { - std::shared_ptr channel = - grpc::CreateChannel("unix://" + raylet_socket, grpc::InsecureChannelCredentials()); - stub_ = RayletService::NewStub(channel); - - rpc_thread_ = std::thread([this]() { main_service_.run(); }); - RAY_LOG(DEBUG) << "Connecting to unix socket: " - << "unix://" + raylet_socket - << ", is worker: " << (is_worker_ ? "true" : "false") - << ", worker id: " << worker_id; - // Try to register client `num_raylet_client_retry_times` times. - TryRegisterClient(RayConfig::instance().num_raylet_client_retry_times()); -} - -void RayletClient::TryRegisterClient(int retry_times) { - // We should block here until register succeeds. 
- for (int i = 0; i < retry_times; i++) { - auto st = RegisterClient(); - if (st.ok()) { - is_connected_ = true; - Heartbeat(); - return; - } - std::this_thread::sleep_for(std::chrono::milliseconds(500)); - } - RAY_LOG(FATAL) << "Worker " << worker_id_ - << " failed to register to raylet server, worker id: " << worker_id_ - << ", pid: " << static_cast(getpid()) - << ", is worker: " << is_worker_; -} - -RayletClient::~RayletClient() { - is_connected_ = false; - main_service_.stop(); - rpc_thread_.join(); -} - -ray::Status RayletClient::Disconnect() { - DisconnectClientRequest disconnect_client_request; - disconnect_client_request.set_worker_id(worker_id_.Binary()); - - DisconnectClientReply reply; - grpc::ClientContext context; - auto status = stub_->DisconnectClient(&context, disconnect_client_request, &reply); - if (!status.ok()) { - RAY_LOG(ERROR) << "Worker " << worker_id_ - << " failed to disconnect from raylet, msg: " - << status.error_message(); - } - return GrpcStatusToRayStatus(status); -} - -ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { - RETURN_IF_DISCONNECTED(is_connected_); - SubmitTaskRequest submit_task_request; - submit_task_request.set_worker_id(worker_id_.Binary()); - submit_task_request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); - - auto callback = [this](const Status &status, const SubmitTaskReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send SubmitTaskRequest, msg: " << status.message(); - } - }; - - auto call = - client_call_manager_.CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncSubmitTask, submit_task_request, - callback); - return call->GetStatus(); -} - -ray::Status RayletClient::GetTask(std::unique_ptr *task_spec) { - RETURN_IF_DISCONNECTED(is_connected_); - GetTaskRequest get_task_request; - get_task_request.set_worker_id(worker_id_.Binary()); - - grpc::ClientContext context; - GetTaskReply 
reply; - // The actual RPC. - auto status = stub_->GetTask(&context, get_task_request, &reply); - - if (status.ok()) { - resource_ids_.clear(); - // Parse resources that would be used by this assigned task. - for (int64_t i = 0; i < reply.fractional_resource_ids().size(); ++i) { - auto const &fractional_resource_ids = reply.fractional_resource_ids()[i]; - auto &acquired_resources = resource_ids_[fractional_resource_ids.resource_name()]; - - // Each resource includes a series of resource IDs (e.g., GPU 0) and corresponding - // amount for that resource ID. If the resource amount is fractional, then there - // should only be one resource ID. - size_t num_resource_ids = fractional_resource_ids.resource_ids().size(); - size_t num_resource_fractions = fractional_resource_ids.resource_fractions().size(); - RAY_CHECK(num_resource_ids == num_resource_fractions); - RAY_CHECK(num_resource_ids > 0); - for (size_t j = 0; j < num_resource_ids; ++j) { - int64_t resource_id = fractional_resource_ids.resource_ids()[j]; - double resource_fraction = fractional_resource_ids.resource_fractions()[j]; - if (num_resource_ids > 1) { - int64_t whole_fraction = resource_fraction; - RAY_CHECK(whole_fraction == resource_fraction); - } - acquired_resources.emplace_back(resource_id, resource_fraction); - } - } - task_spec->reset(new ray::TaskSpecification(reply.task_spec())); - } else { - *task_spec = nullptr; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to get task, msg: " << status.error_message(); - } - return GrpcStatusToRayStatus(status); -} - -ray::Status RayletClient::TaskDone() { - RETURN_IF_DISCONNECTED(is_connected_); - TaskDoneRequest task_done_request; - task_done_request.set_worker_id(worker_id_.Binary()); - - auto callback = [this](const Status &status, const TaskDoneReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send TaskDoneRequest, msg: " << status.message(); - } - }; - - 
auto call = - client_call_manager_.CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncTaskDone, task_done_request, - callback); - return call->GetStatus(); -} - -ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, - bool fetch_only, - const TaskID ¤t_task_id) { - RETURN_IF_DISCONNECTED(is_connected_); - FetchOrReconstructRequest fetch_or_reconstruct_request; - fetch_or_reconstruct_request.set_fetch_only(fetch_only); - fetch_or_reconstruct_request.set_task_id(current_task_id.Binary()); - fetch_or_reconstruct_request.set_worker_id(worker_id_.Binary()); - IdVectorToProtobuf( - object_ids, fetch_or_reconstruct_request, - &FetchOrReconstructRequest::add_object_ids); - - // Callback to deal with reply. - auto callback = [this](const Status &status, const FetchOrReconstructReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send FetchOrReconstructRequest, msg: " - << status.message(); - } - }; - - auto call = - client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncFetchOrReconstruct, - fetch_or_reconstruct_request, callback); - return call->GetStatus(); -} - -ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { - RETURN_IF_DISCONNECTED(is_connected_); - NotifyUnblockedRequest notify_unblocked_request; - notify_unblocked_request.set_worker_id(worker_id_.Binary()); - notify_unblocked_request.set_task_id(current_task_id.Binary()); - - auto callback = [this](const Status &status, const NotifyUnblockedReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send NotifyUnblockedRequest, msg: " - << status.message(); - } - }; - - auto call = - client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncNotifyUnblocked, - notify_unblocked_request, callback); - return call->GetStatus(); -} - -ray::Status 
RayletClient::Wait(const std::vector &object_ids, int num_returns, - int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id, WaitResultPair *result) { - RETURN_IF_DISCONNECTED(is_connected_); - WaitRequest wait_request; - wait_request.set_worker_id(worker_id_.Binary()); - wait_request.set_timeout(timeout_milliseconds); - wait_request.set_wait_local(wait_local); - wait_request.set_task_id(current_task_id.Binary()); - wait_request.set_num_ready_objects(num_returns); - IdVectorToProtobuf(object_ids, wait_request, - &WaitRequest::add_object_ids); - - grpc::ClientContext context; - WaitReply reply; - auto status = stub_->Wait(&context, wait_request, &reply); - - if (status.ok()) { - result->first = IdVectorFromProtobuf(reply.found()); - result->second = IdVectorFromProtobuf(reply.remaining()); - } else { - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send WaitRequest, msg: " << status.error_message(); - } - - return GrpcStatusToRayStatus(status); -} - -ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, - const std::string &error_message, double timestamp) { - RETURN_IF_DISCONNECTED(is_connected_); - PushErrorRequest push_error_request; - push_error_request.set_job_id(job_id.Binary()); - push_error_request.set_type(type); - push_error_request.set_error_message(error_message); - push_error_request.set_timestamp(timestamp); - push_error_request.set_worker_id(worker_id_.Binary()); - - auto callback = [this](const Status &status, const PushErrorReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send PushErrorRequest, msg: " << status.message(); - } - }; - - auto call = - client_call_manager_.CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncPushError, push_error_request, - callback); - return call->GetStatus(); -} - -ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { - 
RETURN_IF_DISCONNECTED(is_connected_); - PushProfileEventsRequest push_profile_events_request; - push_profile_events_request.mutable_profile_table_data()->CopyFrom(profile_events); - push_profile_events_request.set_worker_id(worker_id_.Binary()); - - auto callback = [this](const Status &status, const PushProfileEventsReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send PushProfileEventsRequest, msg: " - << status.message(); - } - }; - - auto call = - client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncPushProfileEvents, - push_profile_events_request, callback); - return call->GetStatus(); -} - -ray::Status RayletClient::FreeObjects(const std::vector &object_ids, - bool local_only, bool delete_creating_tasks) { - RETURN_IF_DISCONNECTED(is_connected_); - FreeObjectsInStoreRequest free_objects_request; - free_objects_request.set_local_only(local_only); - free_objects_request.set_delete_creating_tasks(delete_creating_tasks); - free_objects_request.set_worker_id(worker_id_.Binary()); - IdVectorToProtobuf( - object_ids, free_objects_request, &FreeObjectsInStoreRequest::add_object_ids); - - auto callback = [this](const Status &status, const FreeObjectsInStoreReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send FreeObjectsInStoreRequest, msg: " - << status.message(); - } - }; - - auto call = - client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncFreeObjectsInStore, - free_objects_request, callback); - return call->GetStatus(); -} - -ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, - ActorCheckpointID &checkpoint_id) { - RETURN_IF_DISCONNECTED(is_connected_); - PrepareActorCheckpointRequest prepare_actor_checkpoint_request; - prepare_actor_checkpoint_request.set_actor_id(actor_id.Binary()); - 
prepare_actor_checkpoint_request.set_worker_id(worker_id_.Binary()); - - grpc::ClientContext context; - PrepareActorCheckpointReply reply; - auto status = - stub_->PrepareActorCheckpoint(&context, prepare_actor_checkpoint_request, &reply); - - if (status.ok()) { - checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id()); - } else { - RAY_LOG(INFO) << "Worker " << worker_id_ - << " failed to send PrepareActorCheckpointRequest, msg: " - << status.error_message(); - } - - return GrpcStatusToRayStatus(status); -} - -ray::Status RayletClient::NotifyActorResumedFromCheckpoint( - const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { - RETURN_IF_DISCONNECTED(is_connected_); - NotifyActorResumedFromCheckpointRequest notify_actor_resumed_from_checkpoint_request; - notify_actor_resumed_from_checkpoint_request.set_actor_id(actor_id.Binary()); - notify_actor_resumed_from_checkpoint_request.set_checkpoint_id(checkpoint_id.Binary()); - notify_actor_resumed_from_checkpoint_request.set_worker_id(worker_id_.Binary()); - - auto callback = [this](const Status &status, - const NotifyActorResumedFromCheckpointReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "NotifyActorResumedFromCheckpoint failed, msg: " - << status.message(); - } - }; - - auto call = - client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncNotifyActorResumedFromCheckpoint, - notify_actor_resumed_from_checkpoint_request, callback); - return call->GetStatus(); -} - -ray::Status RayletClient::SetResource(const std::string &resource_name, - const double capacity, - const ray::ClientID &client_id) { - RETURN_IF_DISCONNECTED(is_connected_); - SetResourceRequest set_resource_request; - set_resource_request.set_resource_name(resource_name); - set_resource_request.set_capacity(capacity); - set_resource_request.set_client_id(client_id.Binary()); - set_resource_request.set_worker_id(worker_id_.Binary()); - - auto callback = 
[this](const Status &status, const SetResourceReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "SetResource failed, msg: " << status.message(); - } - }; - - auto call = client_call_manager_ - .CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncSetResource, - set_resource_request, callback); - return call->GetStatus(); -} - -ray::Status RayletClient::RegisterClient() { - RegisterClientRequest register_client_request; - register_client_request.set_is_worker(is_worker_); - register_client_request.set_worker_id(worker_id_.Binary()); - register_client_request.set_worker_pid(getpid()); - register_client_request.set_job_id(job_id_.Binary()); - register_client_request.set_language(language_); - register_client_request.set_port(port_); - - grpc::ClientContext context; - RegisterClientReply reply; - auto status = stub_->RegisterClient(&context, register_client_request, &reply); - - if (!status.ok()) { - RAY_LOG(DEBUG) << "Worker " << worker_id_ - << " failed to register client, msg: " << status.error_message(); - } - - return GrpcStatusToRayStatus(status); -} - -void RayletClient::Heartbeat() { - if (!is_connected_) { - return; - } - HeartbeatRequest heartbeat_request; - heartbeat_request.set_is_worker(is_worker_); - heartbeat_request.set_worker_id(worker_id_.Binary()); - - auto callback = [this](const Status &status, const HeartbeatReply &reply) { - if (!status.ok() && is_connected_) { - is_connected_ = false; - RAY_LOG(INFO) << "Heartbeat failed, msg: " << status.message(); - } - }; - auto call = - client_call_manager_.CreateCall( - *stub_, &RayletService::Stub::PrepareAsyncHeartbeat, heartbeat_request, - callback); - - heartbeat_timer_.expires_from_now(boost::posix_time::milliseconds( - RayConfig::instance().heartbeat_timeout_milliseconds())); - heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { - RAY_CHECK(!error); - Heartbeat(); - }); -} - -} // namespace rpc -} // namespace ray diff --git 
a/src/ray/rpc/raylet/raylet_server.h b/src/ray/rpc/raylet/raylet_server.h deleted file mode 100644 index 608cd26ff..000000000 --- a/src/ray/rpc/raylet/raylet_server.h +++ /dev/null @@ -1,258 +0,0 @@ -#ifndef RAY_RPC_RAYLET_SERVER_H -#define RAY_RPC_RAYLET_SERVER_H - -#include "src/ray/rpc/grpc_server.h" -#include "src/ray/rpc/server_call.h" - -#include "src/ray/protobuf/raylet.grpc.pb.h" -#include "src/ray/protobuf/raylet.pb.h" - -namespace ray { -namespace rpc { - -/// Implementations of the `RayletService`, check interface in -/// `src/ray/protobuf/raylet.proto`. -class RayletServiceHandler { - public: - virtual void HandleRegisterClientRequest(const RegisterClientRequest &request, - RegisterClientReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `SubmitTask` request. - /// The implementation can handle this request asynchronously. When handling is done, - /// the `send_reply_callback` should be called. - /// - /// \param[in] request The request message. - /// \param[out] reply The reply message. - /// \param[in] send_reply_callback The callback to be called when the request is done. - virtual void HandleSubmitTaskRequest(const SubmitTaskRequest &request, - SubmitTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `DisconnectClient` request. - virtual void HandleDisconnectClientRequest(const DisconnectClientRequest &request, - DisconnectClientReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `GetTask` request. - virtual void HandleGetTaskRequest(const GetTaskRequest &request, GetTaskReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `TaskDone` request. - virtual void HandleTaskDoneRequest(const TaskDoneRequest &request, TaskDoneReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `HandleFetchOrReconstruct` request. 
- virtual void HandleFetchOrReconstructRequest(const FetchOrReconstructRequest &request, - FetchOrReconstructReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `HandleNotifyUnblocked` request. - virtual void HandleNotifyUnblockedRequest(const NotifyUnblockedRequest &request, - NotifyUnblockedReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `Wait` request. - virtual void HandleWaitRequest(const WaitRequest &request, WaitReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `PushError` request. - virtual void HandlePushErrorRequest(const PushErrorRequest &request, - PushErrorReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `PushProfileEvents` request. - virtual void HandlePushProfileEventsRequest(const PushProfileEventsRequest &request, - PushProfileEventsReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `FreeObjectsInStoreInObjectStore` request. - virtual void HandleFreeObjectsInStoreRequest(const FreeObjectsInStoreRequest &request, - FreeObjectsInStoreReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `PrepareActorCheckpoint` request. - virtual void HandlePrepareActorCheckpointRequest( - const PrepareActorCheckpointRequest &request, PrepareActorCheckpointReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `NotifyActorResumedFromCheckpoint` request. - virtual void HandleNotifyActorResumedFromCheckpointRequest( - const NotifyActorResumedFromCheckpointRequest &request, - NotifyActorResumedFromCheckpointReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `SetResource` request. - virtual void HandleSetResourceRequest(const SetResourceRequest &request, - SetResourceReply *reply, - SendReplyCallback send_reply_callback) = 0; - /// Handle a `SetResourceReply` request. 
- virtual void HandleHeartbeatRequest(const HeartbeatRequest &request, - HeartbeatReply *reply, - SendReplyCallback send_reply_callback) = 0; -}; - -/// The `GrpcService` for `RayletGrpcService`. -class RayletGrpcService : public GrpcService { - public: - /// Construct a `RayletGrpcService`. - /// - /// \param[in] io_service Service used to handle incoming requests - /// \param[in] handler The service handler that actually handle the requests. - RayletGrpcService(boost::asio::io_service &io_service, - RayletServiceHandler &service_handler) - : GrpcService(io_service), service_handler_(service_handler){}; - - protected: - grpc::Service &GetGrpcService() override { return service_; } - - void InitServerCallFactories( - const std::unique_ptr &cq, - std::vector, int>> - *server_call_factories_and_concurrencies) override { - // Initialize the factory for `RegisterClient` requests. - std::unique_ptr register_client_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestRegisterClient, - service_handler_, &RayletServiceHandler::HandleRegisterClientRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(register_client_call_factory), 10); - - // Initialize the factory for `SubmitTask` requests. - std::unique_ptr submit_task_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestSubmitTask, service_handler_, - &RayletServiceHandler::HandleSubmitTaskRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(submit_task_call_factory), 20); - - // Initialize the factory for `DisconnectClient` requests. 
- std::unique_ptr disconnect_client_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestDisconnectClient, - service_handler_, &RayletServiceHandler::HandleDisconnectClientRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(disconnect_client_call_factory), 10); - - // Initialize the factory for `GetTask` requests. - std::unique_ptr get_task_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestGetTask, service_handler_, - &RayletServiceHandler::HandleGetTaskRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(get_task_call_factory), 20); - - // Initialize the factory for `TaskDone` requests. - std::unique_ptr task_done_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestTaskDone, service_handler_, - &RayletServiceHandler::HandleTaskDoneRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(task_done_call_factory), 20); - - // Initialize the factory for `FetchOrReconstruct` requests. - std::unique_ptr fetch_or_reconstruct_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestFetchOrReconstruct, - service_handler_, &RayletServiceHandler::HandleFetchOrReconstructRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(fetch_or_reconstruct_call_factory), 10); - - // Initialize the factory for `NotifyUnblocked` requests. - std::unique_ptr notify_unblocked_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestNotifyUnblocked, - service_handler_, &RayletServiceHandler::HandleNotifyUnblockedRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(notify_unblocked_call_factory), 10); - - // Initialize the factory for `Wait` requests. 
- std::unique_ptr wait_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestWait, service_handler_, - &RayletServiceHandler::HandleWaitRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back(std::move(wait_call_factory), - 20); - - // Initialize the factory for `PushError` requests. - std::unique_ptr push_error_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestPushError, service_handler_, - &RayletServiceHandler::HandlePushErrorRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(push_error_call_factory), 10); - - // Initialize the factory for `PushProfileEvents` requests. - std::unique_ptr push_profile_events_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestPushProfileEvents, - service_handler_, &RayletServiceHandler::HandlePushProfileEventsRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(push_profile_events_call_factory), 10); - - // Initialize the factory for `FreeObjectsInStore` requests. - std::unique_ptr free_objects_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestFreeObjectsInStore, - service_handler_, &RayletServiceHandler::HandleFreeObjectsInStoreRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(free_objects_call_factory), 10); - - // Initialize the factory for `PrepareActorCheckpoint` requests. 
- std::unique_ptr prepare_actor_checkpoint_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestPrepareActorCheckpoint, - service_handler_, &RayletServiceHandler::HandlePrepareActorCheckpointRequest, - cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(prepare_actor_checkpoint_call_factory), 10); - - // Initialize the factory for `NotifyActorResumedFromCheckpoint` requests. - std::unique_ptr notify_actor_resumed_from_checkpoint_call_factory( - new ServerCallFactoryImpl( - service_, - &RayletService::AsyncService::RequestNotifyActorResumedFromCheckpoint, - service_handler_, - &RayletServiceHandler::HandleNotifyActorResumedFromCheckpointRequest, cq, - main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(notify_actor_resumed_from_checkpoint_call_factory), 10); - - // Initialize the factory for `SetResource` requests. - std::unique_ptr set_resource_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestSetResource, service_handler_, - &RayletServiceHandler::HandleSetResourceRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(set_resource_call_factory), 10); - - // Initialize the factory for `Heartbeat` requests. - std::unique_ptr heartbeat_call_factory( - new ServerCallFactoryImpl( - service_, &RayletService::AsyncService::RequestHeartbeat, service_handler_, - &RayletServiceHandler::HandleHeartbeatRequest, cq, main_service_)); - server_call_factories_and_concurrencies->emplace_back( - std::move(heartbeat_call_factory), 10); - } - - private: - /// The grpc async service object. - RayletService::AsyncService service_; - /// The service handler that actually handle the requests. - RayletServiceHandler &service_handler_; -}; - -} // namespace rpc -} // namespace ray - -#endif