diff --git a/BUILD.bazel b/BUILD.bazel index 33c4fc793..64b380751 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -52,7 +52,7 @@ proto_library( cc_proto_library( name = "node_manager_cc_proto", - deps = ["node_manager_proto"], + deps = [":node_manager_proto"], ) proto_library( @@ -62,7 +62,21 @@ proto_library( cc_proto_library( name = "object_manager_cc_proto", - deps = ["object_manager_proto"], + deps = [":object_manager_proto"], +) + +proto_library( + name = "raylet_proto", + srcs = ["src/ray/protobuf/raylet.proto"], + deps = [ + ":common_proto", + ":gcs_proto", + ], +) + +cc_proto_library( + name = "raylet_cc_proto", + deps = [":raylet_proto"], ) proto_library( @@ -91,7 +105,7 @@ cc_proto_library( # === Begin of rpc definitions === -# grpc common lib +# GRPC common lib. cc_library( name = "grpc_common_lib", srcs = glob([ @@ -141,7 +155,7 @@ cc_grpc_library( deps = [":object_manager_cc_proto"], ) -# Object manager server and client. +# Object manager rpc server and client. cc_library( name = "object_manager_rpc", hdrs = glob([ @@ -157,7 +171,35 @@ cc_library( ], ) -# worker gRPC lib. +# Raylet gRPC lib. +cc_grpc_library( + name = "raylet_cc_grpc", + srcs = [":raylet_proto"], + grpc_only = True, + deps = [":raylet_cc_proto"], +) + +# Raylet rpc server and client. +cc_library( + name = "raylet_rpc", + srcs = glob([ + "src/ray/rpc/raylet/*.cc", + ]), + hdrs = glob([ + "src/ray/rpc/raylet/*.h", + "src/ray/raylet/*.h", + ]), + copts = COPTS, + deps = [ + ":grpc_common_lib", + ":ray_common", + ":raylet_cc_grpc", + "@boost//:asio", + "@com_github_grpc_grpc//:grpc++", + ], +) + +# Worker gRPC lib. cc_grpc_library( name = "worker_cc_grpc", srcs = [":worker_proto"], @@ -165,7 +207,7 @@ cc_grpc_library( deps = [":worker_cc_proto"], ) -# worker server and client. +# Worker server and client. cc_library( name = "worker_rpc", hdrs = glob([ @@ -201,7 +243,6 @@ cc_library( copts = COPTS, deps = [ ":common_cc_proto", - ":node_manager_fbs", ":ray_util", "@boost//:asio", "@com_github_grpc_grpc//:grpc++", @@ -300,11 +341,11 @@ cc_library( deps = [ ":common_cc_proto", ":gcs", - ":node_manager_fbs", ":node_manager_rpc", ":object_manager", ":ray_common", ":ray_util", + ":raylet_rpc", ":stats_lib", ":worker_rpc", "@boost//:asio", @@ -375,7 +416,6 @@ cc_test( srcs = ["src/ray/raylet/lineage_cache_test.cc"], copts = COPTS, deps = [ - ":node_manager_fbs", ":raylet_lib", "@com_google_googletest//:gtest_main", ], @@ -386,7 +426,6 @@ cc_test( srcs = ["src/ray/raylet/reconstruction_policy_test.cc"], copts = COPTS, deps = [ - ":node_manager_fbs", ":object_manager", ":raylet_lib", "@com_google_googletest//:gtest_main", @@ -576,7 +615,6 @@ cc_library( deps = [ ":gcs_cc_proto", ":hiredis", - ":node_manager_fbs", ":node_manager_rpc", ":ray_common", ":ray_util", @@ -632,13 +670,6 @@ flatbuffer_cc_library( out_prefix = "src/ray/common/", ) -flatbuffer_cc_library( - name = "node_manager_fbs", - srcs = ["src/ray/raylet/format/node_manager.fbs"], - flatc_args = FLATC_ARGS, - out_prefix = "src/ray/raylet/format/", -) - flatbuffer_cc_library( name = "object_manager_fbs", srcs = ["src/ray/object_manager/format/object_manager.fbs"], diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index a1e11141e..af94d933e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -37,9 +37,9 @@ public class RayletClientImpl implements RayletClient { private long client = 0; // TODO(qwang): JobId parameter can be removed once we embed jobId in driverId. - public RayletClientImpl(String schedulerSockName, UniqueId clientId, + public RayletClientImpl(String schedulerSockName, UniqueId workerId, boolean isWorker, JobId jobId) { - client = nativeInit(schedulerSockName, clientId.getBytes(), + client = nativeInit(schedulerSockName, workerId.getBytes(), isWorker, jobId.getBytes()); } diff --git a/python/ray/__init__.py b/python/ray/__init__.py index f19a44255..f68a990af 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -5,6 +5,13 @@ from __future__ import print_function import os import sys +# MUST import ray._raylet before pyarrow to initialize some global variables. +# It seems the library related to memory allocation in pyarrow will destroy the +# initialization of grpc if we import pyarrow at first. +# NOTE(JoeyJiang): See https://github.com/ray-project/ray/issues/5219 for more +# details. +import ray._raylet + if "pyarrow" in sys.modules: raise ImportError("Ray must be imported before pyarrow because Ray " "requires a specific version of pyarrow (which is " diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 83d48eafc..67216f53b 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -220,14 +220,14 @@ cdef class RayletClient: cdef unique_ptr[CRayletClient] client def __cinit__(self, raylet_socket, - ClientID client_id, + WorkerID worker_id, c_bool is_worker, JobID job_id): # We know that we are using Python, so just skip the language # parameter. # TODO(suquark): Should we allow unicode chars in "raylet_socket"? self.client.reset(new CRayletClient( - raylet_socket.encode("ascii"), client_id.native(), is_worker, + raylet_socket.encode("ascii"), worker_id.native(), is_worker, job_id.native(), LANGUAGE_PYTHON)) def disconnect(self): @@ -374,7 +374,7 @@ cdef class RayletClient: @property def client_id(self): - return ClientID(self.client.get().GetClientID().Binary()) + return ClientID(self.client.get().GetWorkerId().Binary()) @property def job_id(self): diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 18a248304..97bb318a0 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -88,16 +88,16 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: cdef extern from "ray/protobuf/common.pb.h" nogil: - cdef cppclass CLanguage "Language": + cdef cppclass CLanguage "ray::rpc::Language": pass # This is a workaround for C++ enum class since Cython has no corresponding # representation. -cdef extern from "ray/protobuf/common.pb.h" namespace "Language" nogil: - cdef CLanguage LANGUAGE_PYTHON "Language::PYTHON" - cdef CLanguage LANGUAGE_CPP "Language::CPP" - cdef CLanguage LANGUAGE_JAVA "Language::JAVA" +cdef extern from "ray/protobuf/common.pb.h" namespace "ray::rpc::Language" nogil: + cdef CLanguage LANGUAGE_PYTHON "ray::rpc::Language::PYTHON" + cdef CLanguage LANGUAGE_CPP "ray::rpc::Language::CPP" + cdef CLanguage LANGUAGE_JAVA "ray::rpc::Language::JAVA" cdef extern from "ray/common/task/scheduling_resources.h" \ diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 2372ec884..c0ff3e614 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -23,14 +23,14 @@ from ray.includes.task cimport CTaskSpec cdef extern from "ray/protobuf/gcs.pb.h" nogil: - cdef cppclass GCSProfileEvent "ProfileTableData::ProfileEvent": + cdef cppclass GCSProfileEvent "ray::rpc::ProfileTableData::ProfileEvent": void set_event_type(const c_string &value) void set_start_time(double value) void set_end_time(double value) c_string set_extra_data(const c_string &value) GCSProfileEvent() - cdef cppclass GCSProfileTableData "ProfileTableData": + cdef cppclass GCSProfileTableData "ray::rpc::ProfileTableData": void set_component_type(const c_string &value) void set_component_id(const c_string &value) void set_node_ip_address(const c_string &value) @@ -43,13 +43,12 @@ ctypedef unordered_map[c_string, c_vector[pair[int64_t, double]]] \ ctypedef pair[c_vector[CObjectID], c_vector[CObjectID]] WaitResultPair -cdef extern from "ray/raylet/raylet_client.h" nogil: - cdef cppclass CRayletClient "RayletClient": +cdef extern from "ray/rpc/raylet/raylet_client.h" namespace "ray::rpc" nogil: + cdef cppclass CRayletClient "ray::rpc::RayletClient": CRayletClient(const c_string &raylet_socket, - const CClientID &client_id, + const CWorkerID &worker_id, c_bool is_worker, const CJobID &job_id, const CLanguage &language) - CRayStatus Disconnect() CRayStatus SubmitTask(const CTaskSpec &task_spec) CRayStatus GetTask(unique_ptr[CTaskSpec] *task_spec) CRayStatus TaskDone() @@ -73,7 +72,8 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: const CActorID &actor_id, const CActorCheckpointID &checkpoint_id) CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const - CClientID GetClientID() const + CWorkerID GetWorkerId() const CJobID GetJobID() const c_bool IsWorker() const + CRayStatus Disconnect() const ResourceMappingType &GetResourceIDs() const diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index 9d972801e..e4043957f 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -458,11 +458,9 @@ print("success") # Make sure the first driver ran to completion. assert "success" in out - nonexistent_id_bytes = _random_string() - nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) # Define a driver that creates one task that depends on a nonexistent # object. This task will be queued as waiting to execute. - driver_script = """ + driver_script_template = """ import time import ray ray.init(redis_address="{}") @@ -472,11 +470,15 @@ def g(x): g.remote(ray.ObjectID(ray.utils.hex_to_binary("{}"))) time.sleep(1) print("success") -""".format(redis_address, nonexistent_id_hex) +""" # Create some drivers and let them exit and make sure everything is # still alive. for _ in range(3): + nonexistent_id_bytes = _random_string() + nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) + driver_script = driver_script_template.format(redis_address, + nonexistent_id_hex) out = run_string_as_driver(driver_script) # Simulate the nonexistent dependency becoming available. ray.worker.global_worker.put_object( @@ -484,10 +486,8 @@ print("success") # Make sure the first driver ran to completion. assert "success" in out - nonexistent_id_bytes = _random_string() - nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) # Define a driver that calls `ray.wait` on a nonexistent object. - driver_script = """ + driver_script_template = """ import time import ray ray.init(redis_address="{}") @@ -497,11 +497,15 @@ def g(): g.remote() time.sleep(1) print("success") -""".format(redis_address, nonexistent_id_hex) +""" # Create some drivers and let them exit and make sure everything is # still alive. for _ in range(3): + nonexistent_id_bytes = _random_string() + nonexistent_id_hex = ray.utils.binary_to_hex(nonexistent_id_bytes) + driver_script = driver_script_template.format(redis_address, + nonexistent_id_hex) out = run_string_as_driver(driver_script) # Simulate the nonexistent dependency becoming available. ray.worker.global_worker.put_object( diff --git a/python/ray/worker.py b/python/ray/worker.py index e9729d243..57dd263e7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -41,7 +41,6 @@ import ray.state from ray import ( ActorHandleID, ActorID, - ClientID, WorkerID, JobID, ObjectID, @@ -1923,7 +1922,7 @@ def connect(node, worker.raylet_client = ray._raylet.RayletClient( node.raylet_socket_name, - ClientID(worker.worker_id), + WorkerID(worker.worker_id), (mode == WORKER_MODE), worker.current_job_id, ) diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index eb3e98c69..ec8c9ab3c 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -4,7 +4,6 @@ #include #include #include - #include "status.h" namespace ray { @@ -73,6 +72,17 @@ inline std::vector VectorFromProtobuf( return std::vector(pb_repeated.begin(), pb_repeated.end()); } +template +using AddFunction = void (Message::*)(const ::std::string &value); +/// Add a vector of type ID to protobuf message. +template +inline void IdVectorToProtobuf(const std::vector &ids, Message &message, + AddFunction add_func) { + for (const auto &id : ids) { + (message.*add_func)(id.Binary()); + } +} + /// Converts a Protobuf `RepeatedField` to a vector of IDs. template inline std::vector IdVectorFromProtobuf( diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index 462699002..b91b4c773 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -20,8 +20,12 @@ RAY_CONFIG(int64_t, ray_cookie, 0x5241590000000000) /// warning is logged that the handler is taking too long. RAY_CONFIG(int64_t, handler_warning_timeout_ms, 100) -/// The duration between heartbeats. These are sent by the raylet. +/// The duration between heartbeats. This value is used for both worker and raylet. RAY_CONFIG(int64_t, heartbeat_timeout_milliseconds, 100) +/// Worker heartbeats also use `heartbeat_timeout_milliseconds` as timer timeout period. +/// If a worker has not sent a heartbeat in the last `num_worker_heartbeats_timeout` +/// heartbeat intervals, raylet will mark this worker as dead. +RAY_CONFIG(int64_t, num_worker_heartbeats_timeout, 30) /// If a component has not sent a heartbeat in the last num_heartbeats_timeout /// heartbeat intervals, the raylet monitor process will report /// it as dead to the db_client table. @@ -152,6 +156,9 @@ RAY_CONFIG(uint32_t, num_actor_checkpoints_to_keep, 20) /// Maximum number of ids in one batch to send to GCS to delete keys. RAY_CONFIG(uint32_t, maximum_gcs_deletion_batch_size, 1000) +/// Number of times for a raylet client to retry to register. +RAY_CONFIG(int, num_raylet_client_retry_times, 25) + /// When getting objects from object store, print a warning every this number of attempts. RAY_CONFIG(uint32_t, object_store_get_warn_per_num_attempts, 50) diff --git a/src/ray/common/task/scheduling_resources.cc b/src/ray/common/task/scheduling_resources.cc index 5463b0933..4668caff4 100644 --- a/src/ray/common/task/scheduling_resources.cc +++ b/src/ray/common/task/scheduling_resources.cc @@ -674,37 +674,22 @@ std::string ResourceIdSet::ToString() const { return return_string; } -std::vector> ResourceIdSet::ToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb) const { - std::vector> return_message; +std::vector ResourceIdSet::ToProtobuf() const { + std::vector resources; for (auto const &resource_pair : available_resources_) { - std::vector resource_ids; - std::vector resource_fractions; + rpc::ResourceIdSetInfo resource_id_set_info; + resource_id_set_info.set_resource_name(resource_pair.first); for (auto whole_id : resource_pair.second.WholeIds()) { - resource_ids.push_back(whole_id); - resource_fractions.push_back(1); + resource_id_set_info.add_resource_ids(whole_id); + resource_id_set_info.add_resource_fractions(1); } - for (auto const &fractional_pair : resource_pair.second.FractionalIds()) { - resource_ids.push_back(fractional_pair.first); - resource_fractions.push_back(fractional_pair.second.ToDouble()); + resource_id_set_info.add_resource_ids(fractional_pair.first); + resource_id_set_info.add_resource_fractions(fractional_pair.second.ToDouble()); } - - auto resource_id_set_message = protocol::CreateResourceIdSetInfo( - fbb, fbb.CreateString(resource_pair.first), fbb.CreateVector(resource_ids), - fbb.CreateVector(resource_fractions)); - - return_message.push_back(resource_id_set_message); + resources.emplace_back(resource_id_set_info); } - - return return_message; -} - -const std::string ResourceIdSet::Serialize() const { - flatbuffers::FlatBufferBuilder fbb; - auto resource_id_set_flatbuf = ToFlatbuf(fbb); - fbb.Finish(fbb.CreateVector(resource_id_set_flatbuf)); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); + return resources; } /// SchedulingResources class implementation diff --git a/src/ray/common/task/scheduling_resources.h b/src/ray/common/task/scheduling_resources.h index 045c0307b..7076fc408 100644 --- a/src/ray/common/task/scheduling_resources.h +++ b/src/ray/common/task/scheduling_resources.h @@ -6,7 +6,7 @@ #include #include -#include "ray/raylet/format/node_manager_generated.h" +#include "ray/protobuf/common.pb.h" namespace ray { @@ -422,21 +422,13 @@ class ResourceIdSet { /// \return A human-readable string version of the object. std::string ToString() const; - /// \brief Serialize this object using flatbuffers. + /// \brief Convert this object to a vector of protobuf `ResourceIdSetInfo`s. /// - /// \param fbb A flatbuffer builder object. - /// \return A flatbuffer serialized version of this object. - std::vector> ToFlatbuf( - flatbuffers::FlatBufferBuilder &fbb) const; - - /// \brief Serialize this object as a string. - /// - /// \return A serialized string of this object. - /// TODO(zhijunfu): this can be removed after raylet client is migrated to grpc. - const std::string Serialize() const; + /// \return A vector inclusing resource id set infos. + std::vector ToProtobuf() const; private: - /// A mapping from reosurce name to a set of resource IDs for that resource. + /// A mapping from resource name to a set of resource IDs for that resource. std::unordered_map available_resources_; }; diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index 4c37ebdd4..aba8e0eb3 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -3,9 +3,9 @@ #include -#include "ray/common/task/task_common.h" #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" +#include "ray/protobuf/common.pb.h" namespace ray { diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index ef3188266..5f5e439a2 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -6,7 +6,7 @@ #include "ray/common/buffer.h" #include "ray/common/id.h" #include "ray/common/task/task_spec.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" #include "ray/util/util.h" namespace ray { diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index e49ca9972..4a7dc4e59 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -27,9 +27,9 @@ CoreWorker::CoreWorker( // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. raylet_client_ = std::unique_ptr(new RayletClient( - raylet_socket_, ClientID::FromBinary(worker_context_.GetWorkerID().Binary()), - (worker_type_ == WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_, - rpc_server_port)); + raylet_socket_, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), + (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), + language_, rpc_server_port)); } } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 97816bdda..447a1f922 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -7,10 +7,12 @@ #include "ray/core_worker/object_interface.h" #include "ray/core_worker/task_execution.h" #include "ray/core_worker/task_interface.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" namespace ray { +using rpc::RayletClient; + /// The root class that contains all the core and language-independent functionalities /// of the worker. This class is supposed to be used to implement app-language (Java, /// Python, etc) workers. diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index c01941d77..c0ce4e9b9 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -5,7 +5,7 @@ #include "ray/common/buffer.h" #include "ray/core_worker/context.h" #include "ray/core_worker/core_worker.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" #include #include diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc index 3c7bb43a0..5c77e7b07 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_objectstore_ObjectInterfaceImpl.cc @@ -5,6 +5,8 @@ #include "ray/core_worker/lib/java/jni_utils.h" #include "ray/core_worker/object_interface.h" +using ray::rpc::RayletClient; + inline ray::CoreWorkerObjectInterface *GetObjectInterfaceFromPointer( jlong nativeObjectInterfacePointer) { return reinterpret_cast(nativeObjectInterfacePointer); diff --git a/src/ray/core_worker/object_interface.h b/src/ray/core_worker/object_interface.h index cd9e461ee..2f7ddaacd 100644 --- a/src/ray/core_worker/object_interface.h +++ b/src/ray/core_worker/object_interface.h @@ -11,6 +11,8 @@ namespace ray { +using rpc::RayletClient; + class CoreWorker; class CoreWorkerStoreProvider; diff --git a/src/ray/core_worker/store_provider/local_plasma_provider.h b/src/ray/core_worker/store_provider/local_plasma_provider.h index d67b916b9..912709c46 100644 --- a/src/ray/core_worker/store_provider/local_plasma_provider.h +++ b/src/ray/core_worker/store_provider/local_plasma_provider.h @@ -7,7 +7,7 @@ #include "ray/common/status.h" #include "ray/core_worker/common.h" #include "ray/core_worker/store_provider/store_provider.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" namespace ray { diff --git a/src/ray/core_worker/store_provider/plasma_store_provider.h b/src/ray/core_worker/store_provider/plasma_store_provider.h index 1ef6a8b1d..797b71834 100644 --- a/src/ray/core_worker/store_provider/plasma_store_provider.h +++ b/src/ray/core_worker/store_provider/plasma_store_provider.h @@ -8,10 +8,12 @@ #include "ray/core_worker/common.h" #include "ray/core_worker/store_provider/local_plasma_provider.h" #include "ray/core_worker/store_provider/store_provider.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" namespace ray { +using rpc::RayletClient; + class CoreWorker; /// The class provides implementations for accessing plasma store, which includes both diff --git a/src/ray/core_worker/task_interface.h b/src/ray/core_worker/task_interface.h index e1ec502b6..e261802fe 100644 --- a/src/ray/core_worker/task_interface.h +++ b/src/ray/core_worker/task_interface.h @@ -15,6 +15,8 @@ namespace ray { +using rpc::RayletClient; + class CoreWorker; /// Options of a non-actor-creation task. diff --git a/src/ray/core_worker/transport/raylet_transport.cc b/src/ray/core_worker/transport/raylet_transport.cc index 9706b2ce4..9f4459126 100644 --- a/src/ray/core_worker/transport/raylet_transport.cc +++ b/src/ray/core_worker/transport/raylet_transport.cc @@ -56,8 +56,8 @@ void CoreWorkerRayletTaskReceiver::HandleAssignTask( // rpc reply first before the NotifyUnblocked message arrives, // as they use different connections, the `TaskDone` message is sent // to raylet via the same connection so the order is guaranteed. - raylet_client_->TaskDone(); - // send rpc reply. + RAY_UNUSED(raylet_client_->TaskDone()); + // Send rpc reply. send_reply_callback(status, nullptr, nullptr); } diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index 0ba8feb5e..10ed146c0 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -5,11 +5,13 @@ #include "ray/core_worker/object_interface.h" #include "ray/core_worker/transport/transport.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" #include "ray/rpc/worker/worker_server.h" namespace ray { +using rpc::RayletClient; + /// In raylet task submitter and receiver, a task is submitted to raylet, and possibly /// gets forwarded to another raylet on which node the task should be executed, and /// then a worker on that node gets this task and starts executing it. diff --git a/src/ray/object_manager/object_manager.cc b/src/ray/object_manager/object_manager.cc index b8458bcbe..5740d190c 100644 --- a/src/ray/object_manager/object_manager.cc +++ b/src/ray/object_manager/object_manager.cc @@ -769,9 +769,8 @@ void ObjectManager::SpreadFreeObjectsRequest( const std::vector> &rpc_clients) { // This code path should be called from node manager. rpc::FreeObjectsRequest free_objects_request; - for (const auto &e : object_ids) { - free_objects_request.add_object_ids(e.Binary()); - } + IdVectorToProtobuf( + object_ids, free_objects_request, &rpc::FreeObjectsRequest::add_object_ids); for (auto &rpc_client : rpc_clients) { rpc_client->FreeObjects(free_objects_request, [](const Status &status, diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 48a3f1cf7..6c6b01b5e 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -11,6 +11,18 @@ enum Language { CPP = 2; } +// Resource id set info. +message ResourceIdSetInfo { + // The name of the resource. + bytes resource_name = 1; + // The resource IDs reserved for this worker. + repeated uint64 resource_ids = 2; + // The fraction of each resource ID that is reserved for this worker. Note + // that the length of this list must be the same as the length of + // resource_ids. + repeated double resource_fractions = 3; +} + // Type of a worker. enum WorkerType { WORKER = 0; diff --git a/src/ray/protobuf/raylet.proto b/src/ray/protobuf/raylet.proto new file mode 100644 index 000000000..0cff32d1f --- /dev/null +++ b/src/ray/protobuf/raylet.proto @@ -0,0 +1,203 @@ +syntax = "proto3"; + +package ray.rpc; + +import "src/ray/protobuf/common.proto"; +import "src/ray/protobuf/gcs.proto"; + +/// Service request and reply messages. +message RegisterClientRequest { + // Indicates the client is a worker or a driver. + bool is_worker = 1; + // The worker id. + bytes worker_id = 2; + // The process ID of this worker. + uint32 worker_pid = 3; + // The job ID. + bytes job_id = 4; + // Language of this worker. + Language language = 5; + // Port that this worker is listening on. + // If port > 0, then worker will listen to this port and wait for + // raylet to push tasks, instead of invoking GetTask(). + int32 port = 6; +} +message RegisterClientReply { + repeated int32 gpu_ids = 1; +} + +message SubmitTaskRequest { + TaskSpec task_spec = 1; +} +message SubmitTaskReply { +} + +message DisconnectClientRequest { + bytes worker_id = 1; +} +message DisconnectClientReply { +} + +message GetTaskRequest { + bytes worker_id = 1; +} +message GetTaskReply { + // A string of bytes representing the task specification. + bytes task_spec = 1; + // A list of the resources reserved for this worker. + repeated ResourceIdSetInfo fractional_resource_ids = 2; +} + +message TaskDoneRequest { + bytes worker_id = 1; +} +message TaskDoneReply { +} + +message FetchOrReconstructRequest { + // List of object IDs of the objects that we want to reconstruct or fetch. + repeated bytes object_ids = 1; + // Indicates that we only want to fetch objects, not reconstruct them. + bool fetch_only = 2; + // The current task ID. If fetch_only is false, then this task is blocked. + bytes task_id = 3; + // The worker ID. + bytes worker_id = 4; +} +message FetchOrReconstructReply { +} + +message NotifyUnblockedRequest { + bytes worker_id = 1; + // The current task ID. This task is no longer blocked. + bytes task_id = 2; +} +message NotifyUnblockedReply { +} + +message WaitRequest { + // List of object ids we'll be waiting on. + repeated bytes object_ids = 1; + // Number of objects expected to be returned, if available. + uint64 num_ready_objects = 2; + // Timeout in milliseconds. + int64 timeout = 3; + // Whether to wait until objects appear locally. + bool wait_local = 4; + // The current task ID. If there are less than num_ready_objects local, then + // this task is blocked. + bytes task_id = 5; + // The worker ID. + bytes worker_id = 6; +} +message WaitReply { + // List of object ids found. + repeated bytes found = 1; + // List of object ids not found. + repeated bytes remaining = 2; +} + +message PushErrorRequest { + // The job id that the error is for. + bytes job_id = 1; + // The type of the error. + bytes type = 2; + // The error message. + bytes error_message = 3; + // The timestamp of the error message. + double timestamp = 4; +} +message PushErrorReply { +} + +message PushProfileEventsRequest { + ProfileTableData profile_table_data = 1; +} +message PushProfileEventsReply { +} + +message FreeObjectsInStoreRequest { + // Whether keep this request within the local object store + // or send it to all of the object stores. + bool local_only = 1; + // Whether also delete objects' creating tasks from GCS. + bool delete_creating_tasks = 2; + // List of object ids to delete from the object store. + repeated bytes object_ids = 3; +} +message FreeObjectsInStoreReply { +} + +message PrepareActorCheckpointRequest { + bytes actor_id = 1; + bytes worker_id = 2; +} +message PrepareActorCheckpointReply { + bytes checkpoint_id = 1; +} + +message NotifyActorResumedFromCheckpointRequest { + // ID of the actor that resumed. + bytes actor_id = 1; + // ID of the checkpoint from which the actor was resumed. + bytes checkpoint_id = 2; +} +message NotifyActorResumedFromCheckpointReply { +} + +message SetResourceRequest { + // Name of the resource to be set. + bytes resource_name = 1; + // Capacity of the resource to be set. + double capacity = 2; + // Client ID where this resource will be set. + bytes client_id = 3; +} +message SetResourceReply { +} + +message HeartbeatRequest { + bytes worker_id = 1; + bool is_worker = 2; +} +message HeartbeatReply { +} + +/// Worker-to-raylet RPC service interface. +service RayletService { + // Register a new worker to the raylet. + rpc RegisterClient(RegisterClientRequest) returns (RegisterClientReply); + // Submit a task to the raylet. + rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply); + // Disconnect this client from raylet gracefully. + rpc DisconnectClient(DisconnectClientRequest) returns (DisconnectClientReply); + // Get a new task from the raylet. + rpc GetTask(GetTaskRequest) returns (GetTaskReply); + // Notify the raylet that a task is finished. + rpc TaskDone(TaskDoneRequest) returns (TaskDoneReply); + // Reconstruct or fetch possibly lost objects. + rpc FetchOrReconstruct(FetchOrReconstructRequest) returns (FetchOrReconstructReply); + // For a worker that was blocked on some object(s), tell the raylet + // that the worker is now unblocked. + rpc NotifyUnblocked(NotifyUnblockedRequest) returns (NotifyUnblockedReply); + // Wait for objects to be ready either from local or remote plasma stores. + // The `WaitReply` contains the objects found and objects remaining. + rpc Wait(WaitRequest) returns (WaitReply); + // Push an error to the relevant driver. + rpc PushError(PushErrorRequest) returns (PushErrorReply); + // Push some profiling events to the GCS. When sending this message to the + // node manager, the message itself is serialized as a ProfileTableData object. + rpc PushProfileEvents(PushProfileEventsRequest) returns (PushProfileEventsReply); + // Free the objects in plasma objects store. + rpc FreeObjectsInStore(FreeObjectsInStoreRequest) returns (FreeObjectsInStoreReply); + // Request raylet backend to prepare a checkpoint for an actor. + rpc PrepareActorCheckpoint(PrepareActorCheckpointRequest) + returns (PrepareActorCheckpointReply); + // Notify raylet backend that an actor was resumed from a checkpoint. + rpc NotifyActorResumedFromCheckpoint(NotifyActorResumedFromCheckpointRequest) + returns (NotifyActorResumedFromCheckpointReply); + // Set dynamic custom resource. + rpc SetResource(SetResourceRequest) returns (SetResourceReply); + // Send a heartbeat message to raylet. + rpc Heartbeat(HeartbeatRequest) returns (HeartbeatReply); +} diff --git a/src/ray/protobuf/worker.proto b/src/ray/protobuf/worker.proto index 3c2d30ab6..3e2f83055 100644 --- a/src/ray/protobuf/worker.proto +++ b/src/ray/protobuf/worker.proto @@ -8,9 +8,7 @@ message AssignTaskRequest { // The task to be pushed. Task task = 1; // A list of the resources reserved for this worker. - // TODO(zhijunfu): `resource_ids` is represented as - // flatbutters-serialized bytes, will be moved to protobuf later. - bytes resource_ids = 2; + repeated ResourceIdSetInfo resource_ids = 2; } message AssignTaskReply { diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index fb4390fb5..51cfaca27 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -4,13 +4,17 @@ #include "ray/common/id.h" #include "ray/core_worker/lib/java/jni_utils.h" -#include "ray/raylet/raylet_client.h" +#include "ray/rpc/raylet/raylet_client.h" #include "ray/util/logging.h" #ifdef __cplusplus extern "C" { #endif +using ray::ClientID; +using ray::WorkerID; +using ray::rpc::RayletClient; + /* * Class: org_ray_runtime_raylet_RayletClientImpl * Method: nativeInit @@ -19,7 +23,7 @@ extern "C" { JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, jbyteArray jobId) { - const auto worker_id = JavaByteArrayToId(env, workerId); + const auto worker_id = JavaByteArrayToId(env, workerId); const auto job_id = JavaByteArrayToId(env, jobId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); auto raylet_client = new std::unique_ptr( diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 48b850d96..c552b488e 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -7,7 +7,6 @@ #include "ray/common/task/task_execution_spec.h" #include "ray/common/task/task_spec.h" #include "ray/common/task/task_util.h" -#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/lineage_cache.h" namespace ray { diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 7efded15b..9b1d84d31 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -147,7 +147,7 @@ int main(int argc, char *argv[]) { RAY_LOG(DEBUG) << "Starting object manager with configuration: \n" << "rpc_service_threads_number = " << object_manager_config.rpc_service_threads_number - << "object_chunk_size = " << object_manager_config.object_chunk_size; + << ", object_chunk_size = " << object_manager_config.object_chunk_size; // Initialize the node manager. boost::asio::io_service main_service; @@ -171,6 +171,7 @@ int main(int argc, char *argv[]) { server.reset(); gcs_client->Disconnect(); main_service.stop(); + RAY_LOG(INFO) << "Raylet server received SIGTERM message, shutting down..."; }; RAY_CHECK_OK(gcs_client->client_table().Disconnect(shutdown_callback)); // Give a timeout for this Disconnect operation. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 908ea666c..2a97931db 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -6,7 +6,6 @@ #include "ray/common/common_protocol.h" #include "ray/common/id.h" -#include "ray/raylet/format/node_manager_generated.h" #include "ray/stats/stats.h" namespace { @@ -105,6 +104,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, actor_registry_(), node_manager_server_("NodeManager", config.node_manager_port), node_manager_service_(io_service, *this), + raylet_service_(io_service, *this), client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. @@ -249,6 +249,7 @@ void NodeManager::KillWorker(std::shared_ptr worker) { // up its state before force killing. The client socket will be closed // and the worker struct will be freed after the timeout. kill(worker->Pid(), SIGTERM); + worker->MarkAsBeingKilled(); auto retry_timer = std::make_shared(io_service_); auto retry_duration = boost::posix_time::milliseconds( @@ -272,13 +273,10 @@ void NodeManager::HandleJobTableUpdate(const JobID &id, auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id); // Kill all the workers. The actual cleanup for these workers is done - // later when we receive the DisconnectClient message from them. + // later when the worker heartbeats timeout. for (const auto &worker : workers) { // Clean up any open ray.wait calls that the worker made. task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); - // Mark the worker as dead so further messages from it are ignored - // (except DisconnectClient). - worker->MarkDead(); // Then kill the worker process. KillWorker(worker); } @@ -340,6 +338,18 @@ void NodeManager::Heartbeat() { last_debug_dump_at_ms_ = now_ms; } + // Check worker heartbeat timeout times. + std::vector> dead_workers; + worker_pool_.TickHeartbeatTimer(RayConfig::instance().num_worker_heartbeats_timeout(), + &dead_workers); + if (!dead_workers.empty()) { + for (const auto &worker : dead_workers) { + RAY_LOG(INFO) << "Worker " << worker->GetWorkerId() + << " dead because of timeout, pid: " << worker->Pid(); + ProcessDisconnectClientMessage(worker->GetWorkerId(), worker->IsBeingKilled()); + } + } + // Reset the timer. heartbeat_timer_.expires_from_now(heartbeat_period_); heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { @@ -682,9 +692,12 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } -void NodeManager::ProcessNewClient(LocalClientConnection &client) { - // The new client is a worker, so begin listening for messages. - client.ProcessMessages(); +void NodeManager::CleanUpTasksForFinishedJob(const JobID &job_id) { + auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); + task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); + // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must + // call it last. + local_queues_.RemoveTasks(tasks_to_remove); } // A helper function to create a mapping from resource shapes to @@ -701,7 +714,7 @@ std::unordered_map> MakeTasksWithResources( void NodeManager::DispatchTasks( const std::unordered_map> &tasks_with_resources) { - std::unordered_set removed_task_ids; + std::unordered_set assigned_task_ids; for (const auto &it : tasks_with_resources) { const auto &task_resources = it.first; for (const auto &task_id : it.second) { @@ -712,146 +725,51 @@ void NodeManager::DispatchTasks( break; } if (AssignTask(task)) { - removed_task_ids.insert(task_id); + assigned_task_ids.insert(task_id); } } } - // Move the ASSIGNED task to the SWAP queue so that we remember that we have - // it queued locally. Once the GetTaskReply has been sent, the task will get - // re-queued, depending on whether the message succeeded or not. - local_queues_.MoveTasks(removed_task_ids, TaskState::READY, TaskState::SWAP); + + // Move the ASSIGNED task to the RUNNING queue. + // We should move task outside `AssignTask` function because removing + // task might influence the iterator. + local_queues_.MoveTasks(assigned_task_ids, TaskState::READY, TaskState::RUNNING); } -void NodeManager::ProcessClientMessage( - const std::shared_ptr &client, int64_t message_type, - const uint8_t *message_data) { - auto registered_worker = worker_pool_.GetRegisteredWorker(client); - auto message_type_value = static_cast(message_type); - RAY_LOG(DEBUG) << "[Worker] Message " - << protocol::EnumNameMessageType(message_type_value) << "(" - << message_type << ") from worker with PID " - << (registered_worker ? std::to_string(registered_worker->Pid()) - : "nil"); - if (registered_worker && registered_worker->IsDead()) { - // For a worker that is marked as dead (because the job has died already), - // all the messages are ignored except DisconnectClient. - if ((message_type_value != protocol::MessageType::DisconnectClient) && - (message_type_value != protocol::MessageType::IntentionalDisconnectClient)) { - // Listen for more messages. - client->ProcessMessages(); - return; - } - } +void NodeManager::HandleRegisterClientRequest( + const rpc::RegisterClientRequest &request, rpc::RegisterClientReply *reply, + rpc::SendReplyCallback send_reply_callback) { + // Client id in register client is treated as worker id. + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + auto worker = + std::make_shared(worker_id, request.worker_pid(), request.language(), + request.port(), client_call_manager_); - switch (message_type_value) { - case protocol::MessageType::RegisterClientRequest: { - ProcessRegisterClientRequestMessage(client, message_data); - } break; - case protocol::MessageType::GetTask: { - RAY_CHECK(!registered_worker->UsePush()); - HandleWorkerAvailable(client); - } break; - case protocol::MessageType::TaskDone: { - RAY_CHECK(registered_worker->UsePush()); - HandleWorkerAvailable(client); - } break; - case protocol::MessageType::DisconnectClient: { - ProcessDisconnectClientMessage(client); - // We don't need to receive future messages from this client, - // because it's already disconnected. - return; - } break; - case protocol::MessageType::IntentionalDisconnectClient: { - ProcessDisconnectClientMessage(client, /* intentional_disconnect = */ true); - // We don't need to receive future messages from this client, - // because it's already disconnected. - return; - } break; - case protocol::MessageType::SubmitTask: { - ProcessSubmitTaskMessage(message_data); - } break; - case protocol::MessageType::SetResourceRequest: { - ProcessSetResourceRequest(client, message_data); - } break; - case protocol::MessageType::FetchOrReconstruct: { - ProcessFetchOrReconstructMessage(client, message_data); - } break; - case protocol::MessageType::NotifyUnblocked: { - auto message = flatbuffers::GetRoot(message_data); - HandleTaskUnblocked(client, from_flatbuf(*message->task_id())); - } break; - case protocol::MessageType::WaitRequest: { - ProcessWaitRequestMessage(client, message_data); - } break; - case protocol::MessageType::PushErrorRequest: { - ProcessPushErrorRequestMessage(message_data); - } break; - case protocol::MessageType::PushProfileEventsRequest: { - auto fbs_message = flatbuffers::GetRoot(message_data); - rpc::ProfileTableData profile_table_data; - RAY_CHECK( - profile_table_data.ParseFromArray(fbs_message->data(), fbs_message->size())); - RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); - } break; - case protocol::MessageType::FreeObjectsInObjectStoreRequest: { - auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); - // Clean up objects from the object store. - object_manager_.FreeObjects(object_ids, message->local_only()); - if (message->delete_creating_tasks()) { - // Clean up their creating tasks from GCS. - std::vector creating_task_ids; - for (const auto &object_id : object_ids) { - creating_task_ids.push_back(object_id.TaskId()); - } - gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); - } - } break; - case protocol::MessageType::PrepareActorCheckpointRequest: { - ProcessPrepareActorCheckpointRequest(client, message_data); - } break; - case protocol::MessageType::NotifyActorResumedFromCheckpoint: { - ProcessNotifyActorResumedFromCheckpoint(message_data); - } break; + RAY_LOG(DEBUG) << "Received a RegisterClientRequest, worker id: " << worker_id + << ", is worker: " << request.is_worker() + << ", pid: " << request.worker_pid(); - default: - RAY_LOG(FATAL) << "Received unexpected message type " << message_type; - } - - // Listen for more messages. - client->ProcessMessages(); -} - -void NodeManager::ProcessRegisterClientRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data) { - client->Register(); - auto message = flatbuffers::GetRoot(message_data); - Language language = static_cast(message->language()); - WorkerID worker_id = from_flatbuf(*message->worker_id()); - auto worker = std::make_shared(worker_id, message->worker_pid(), language, - message->port(), client, client_call_manager_); - if (message->is_worker()) { + if (request.is_worker()) { // Register the new worker. bool use_push_task = worker->UsePush(); - auto connection = worker->Connection(); - worker_pool_.RegisterWorker(std::move(worker)); + worker_pool_.RegisterWorker(worker_id, std::move(worker)); if (use_push_task) { // only call `HandleWorkerAvailable` when push mode is used. - HandleWorkerAvailable(connection); + HandleWorkerAvailable(worker_id); } } else { // Register the new driver. - const JobID job_id = from_flatbuf(*message->job_id()); - // Compute a dummy driver task id from a given driver. - const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id); + auto driver_task_id = TaskID::ComputeDriverTaskId(worker_id); + auto job_id = JobID::FromBinary(request.job_id()); worker->AssignTaskId(driver_task_id); worker->AssignJobId(job_id); - worker_pool_.RegisterDriver(std::move(worker)); + worker_pool_.RegisterDriver(worker_id, std::move(worker)); local_queues_.AddDriverTaskId(driver_task_id); RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( job_id, /*is_dead=*/false, std::time(nullptr), - initial_config_.node_manager_address, message->worker_pid())); + initial_config_.node_manager_address, request.worker_pid())); } + send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local, @@ -882,8 +800,8 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca if (was_local) { // If the actor was local, immediately update the state in actor registry. // So if we receive any actor tasks before we receive GCS notification, - // these tasks can be correctly routed to the `MethodsWaitingForActorCreation` queue, - // instead of being assigned to the dead actor. + // these tasks can be correctly routed to the `MethodsWaitingForActorCreation` + // queue, instead of being assigned to the dead actor. HandleActorStateTransition(actor_id, ActorRegistration(new_actor_data)); } @@ -898,40 +816,62 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca RAY_CHECK_OK(gcs_client_->Actors().AsyncUpdate(actor_id, actor_notification, done)); } -void NodeManager::HandleWorkerAvailable( - const std::shared_ptr &client) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); - RAY_CHECK(worker); - // If the worker was assigned a task, mark it as finished. - if (!worker->GetAssignedTaskId().IsNil()) { - FinishAssignedTask(*worker); +void NodeManager::HandleGetTaskRequest(const rpc::GetTaskRequest &request, + rpc::GetTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); + RAY_LOG(DEBUG) << "Received a GetTaskRequest, worker id " << worker_id << " pid " + << worker->Pid(); + if (!worker || worker->IsBeingKilled()) { + send_reply_callback(Status::Invalid("WorkerBeingKilled"), nullptr, nullptr); + return; } + RAY_CHECK(!worker->UsePush()); - // Return the worker to the idle pool. - worker_pool_.PushWorker(std::move(worker)); - // Local resource availability changed: invoke scheduling policy for local node. - const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); - cluster_resource_map_[local_client_id].SetLoadResources( - local_queues_.GetResourceLoad()); - // Call task dispatch to assign work to the new worker. - DispatchTasks(local_queues_.GetReadyTasksWithResources()); + // Reply would be sent when assigned a task to the worker successfully. + worker->SetGetTaskReplyAndCallback(reply, std::move(send_reply_callback)); + HandleWorkerAvailable(worker_id); } -void NodeManager::ProcessDisconnectClientMessage( - const std::shared_ptr &client, bool intentional_disconnect) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); +void NodeManager::HandleTaskDoneRequest(const rpc::TaskDoneRequest &request, + rpc::TaskDoneReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + RAY_LOG(DEBUG) << "Received a TaskDoneRequest from worker " << worker_id; + + auto worker = worker_pool_.GetRegisteredWorker(worker_id); + RAY_CHECK(worker && worker->UsePush()); + HandleWorkerAvailable(worker_id); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::HandleDisconnectClientRequest( + const rpc::DisconnectClientRequest &request, rpc::DisconnectClientReply *reply, + rpc::SendReplyCallback send_reply_callback) { + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + RAY_LOG(DEBUG) << "Received a DisconnectClientRequest from worker " << worker_id; + + ProcessDisconnectClientMessage(worker_id, true); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::ProcessDisconnectClientMessage(const WorkerID &worker_id, + bool intentional_disconnect) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); bool is_worker = false, is_driver = false; if (worker) { // The client is a worker. is_worker = true; } else { - worker = worker_pool_.GetRegisteredDriver(client); + worker = worker_pool_.GetRegisteredDriver(worker_id); if (worker) { // The client is a driver. is_driver = true; } else { RAY_LOG(INFO) << "Ignoring client disconnect because the client has already " << "been disconnected."; + return; } } RAY_CHECK(!(is_worker && is_driver)); @@ -939,9 +879,9 @@ void NodeManager::ProcessDisconnectClientMessage( // If the client has any blocked tasks, mark them as unblocked. In // particular, we are no longer waiting for their dependencies. if (worker) { - if (is_worker && worker->IsDead()) { - // Don't need to unblock the client if it's a worker and is already dead. - // Because in this case, its task is already cleaned up. + if (is_worker && worker->IsBeingKilled()) { + // Don't need to unblock the client if it's a worker and have sent kill signal to + // it. Because in this case, its task is already cleaned up. RAY_LOG(DEBUG) << "Skip unblocking worker because it's already dead."; } else { // Clean up any open ray.get calls that the worker made. @@ -949,7 +889,7 @@ void NodeManager::ProcessDisconnectClientMessage( // NOTE(swang): HandleTaskUnblocked will modify the worker, so it is // not safe to pass in the iterator directly. const TaskID task_id = *worker->GetBlockedTaskIds().begin(); - HandleTaskUnblocked(client, task_id); + HandleTaskUnblocked(worker_id, task_id); } // Clean up any open ray.wait calls that the worker made. task_dependency_manager_.UnsubscribeWaitDependencies(worker->WorkerId()); @@ -957,24 +897,17 @@ void NodeManager::ProcessDisconnectClientMessage( } if (is_worker) { - // The client is a worker. - if (worker->IsDead()) { - // If the worker was killed by us because the driver exited, - // treat it as intentionally disconnected. - intentional_disconnect = true; - } - const ActorID &actor_id = worker->GetActorId(); if (!actor_id.IsNil()) { - // If the worker was an actor, update actor state, reconstruct the actor if needed, - // and clean up actor's tasks if the actor is permanently dead. + // If the worker was an actor, update actor state, reconstruct the actor if + // needed, and clean up actor's tasks if the actor is permanently dead. HandleDisconnectedActor(actor_id, true, intentional_disconnect); } const TaskID &task_id = worker->GetAssignedTaskId(); // If the worker was running a task, clean up the task and push an error to - // the driver, unless the worker is already dead. - if (!task_id.IsNil() && !worker->IsDead()) { + // the driver, unless the worker is already being killed. + if (!task_id.IsNil() && !worker->IsBeingKilled()) { // If the worker was an actor, the task was already cleaned up in // `HandleDisconnectedActor`. if (actor_id.IsNil()) { @@ -1035,32 +968,34 @@ void NodeManager::ProcessDisconnectClientMessage( << "job_id: " << job_id; } - client->Close(); - // TODO(rkn): Tell the object manager that this client has disconnected so // that it can clean up the wait requests for this client. Currently I think // these can be leaked. } -void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { - // Read the task submitted by the client. - auto fbs_message = flatbuffers::GetRoot(message_data); - rpc::Task task_message; - RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( - fbs_message->task_spec()->data(), fbs_message->task_spec()->size())); +void NodeManager::HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, + rpc::SubmitTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a SubmitTaskRequest."; + + rpc::Task task; + task.mutable_task_spec()->CopyFrom(request.task_spec()); // Submit the task to the raylet. Since the task was submitted // locally, there is no uncommitted lineage. - SubmitTask(Task(task_message), Lineage()); + SubmitTask(Task(task), Lineage()); + send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::ProcessFetchOrReconstructMessage( - const std::shared_ptr &client, const uint8_t *message_data) { - auto message = flatbuffers::GetRoot(message_data); +void NodeManager::HandleFetchOrReconstructRequest( + const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a FetchOrReconstructRequest."; + const auto &object_ids = request.object_ids(); std::vector required_object_ids; - for (size_t i = 0; i < message->object_ids()->size(); ++i) { - ObjectID object_id = from_flatbuf(*message->object_ids()->Get(i)); - if (message->fetch_only()) { + for (size_t i = 0; i < object_ids.size(); ++i) { + ObjectID object_id = ObjectID::FromBinary(object_ids[i]); + if (request.fetch_only()) { // If only a fetch is required, then do not subscribe to the // dependencies to the task dependency manager. if (!task_dependency_manager_.CheckObjectLocal(object_id)) { @@ -1077,19 +1012,22 @@ void NodeManager::ProcessFetchOrReconstructMessage( } if (!required_object_ids.empty()) { - const TaskID task_id = from_flatbuf(*message->task_id()); - HandleTaskBlocked(client, required_object_ids, task_id, /*ray_get=*/true); + const TaskID task_id = TaskID::FromBinary(request.task_id()); + const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id()); + HandleTaskBlocked(worker_id, required_object_ids, task_id, /*ray_get=*/true); } + send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::ProcessWaitRequestMessage( - const std::shared_ptr &client, const uint8_t *message_data) { +void NodeManager::HandleWaitRequest(const rpc::WaitRequest &request, + rpc::WaitReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a WaitRequest."; // Read the data. - auto message = flatbuffers::GetRoot(message_data); - std::vector object_ids = from_flatbuf(*message->object_ids()); - int64_t wait_ms = message->timeout(); - uint64_t num_required_objects = static_cast(message->num_ready_objects()); - bool wait_local = message->wait_local(); + std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); + int64_t wait_ms = request.timeout(); + uint64_t num_required_objects = request.num_ready_objects(); + bool wait_local = request.wait_local(); std::vector required_object_ids; for (auto const &object_id : object_ids) { @@ -1101,63 +1039,56 @@ void NodeManager::ProcessWaitRequestMessage( } } - const TaskID ¤t_task_id = from_flatbuf(*message->task_id()); + const TaskID ¤t_task_id = TaskID::FromBinary(request.task_id()); + const WorkerID &worker_id = WorkerID::FromBinary(request.worker_id()); bool client_blocked = !required_object_ids.empty(); if (client_blocked) { - HandleTaskBlocked(client, required_object_ids, current_task_id, /*ray_get=*/false); + HandleTaskBlocked(worker_id, required_object_ids, current_task_id, /*ray_get=*/false); } ray::Status status = object_manager_.Wait( object_ids, wait_ms, num_required_objects, wait_local, - [this, client_blocked, client, current_task_id](std::vector found, - std::vector remaining) { - // Write the data. - flatbuffers::FlatBufferBuilder fbb; - flatbuffers::Offset wait_reply = protocol::CreateWaitReply( - fbb, to_flatbuf(fbb, found), to_flatbuf(fbb, remaining)); - fbb.Finish(wait_reply); + [this, client_blocked, worker_id, current_task_id, reply, send_reply_callback]( + std::vector found, std::vector remaining) { + IdVectorToProtobuf(found, *reply, + &rpc::WaitReply::add_found); + IdVectorToProtobuf(remaining, *reply, + &rpc::WaitReply::add_remaining); - auto status = - client->WriteMessage(static_cast(protocol::MessageType::WaitReply), - fbb.GetSize(), fbb.GetBufferPointer()); - if (status.ok()) { - // The client is unblocked now because the wait call has returned. - if (client_blocked) { - HandleTaskUnblocked(client, current_task_id); - } - } else { - // We failed to write to the client, so disconnect the client. - RAY_LOG(WARNING) - << "Failed to send WaitReply to client, so disconnecting client"; - // We failed to send the reply to the client, so disconnect the worker. - ProcessDisconnectClientMessage(client); + // Send reply to finish this wait request. + send_reply_callback(Status::OK(), nullptr, nullptr); + if (client_blocked) { + HandleTaskUnblocked(worker_id, current_task_id); } }); RAY_CHECK_OK(status); } -void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { - auto message = flatbuffers::GetRoot(message_data); - - JobID job_id = from_flatbuf(*message->job_id()); - auto const &type = string_from_flatbuf(*message->type()); - auto const &error_message = string_from_flatbuf(*message->error_message()); - double timestamp = message->timestamp(); +void NodeManager::HandlePushErrorRequest(const rpc::PushErrorRequest &request, + rpc::PushErrorReply *reply, + rpc::SendReplyCallback send_reply_callback) { + JobID job_id = JobID::FromBinary(request.job_id()); + const auto &type = request.type(); + const auto &error_message = request.error_message(); + double timestamp = request.timestamp(); + RAY_LOG(DEBUG) << "Handle push error request for job " << job_id << ", type " << type + << " error message " << error_message; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, timestamp)); + send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::ProcessPrepareActorCheckpointRequest( - const std::shared_ptr &client, const uint8_t *message_data) { - auto message = - flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); +void NodeManager::HandlePrepareActorCheckpointRequest( + const rpc::PrepareActorCheckpointRequest &request, + rpc::PrepareActorCheckpointReply *reply, rpc::SendReplyCallback send_reply_callback) { + ActorID actor_id = ActorID::FromBinary(request.actor_id()); RAY_LOG(DEBUG) << "Preparing checkpoint for actor " << actor_id; const auto &actor_entry = actor_registry_.find(actor_id); RAY_CHECK(actor_entry != actor_registry_.end()); - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); RAY_CHECK(worker && worker->GetActorId() == actor_id); // Find the task that is running on this actor. @@ -1171,44 +1102,34 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( JobID::Nil(), checkpoint_id, checkpoint_data, - [worker, actor_id, this](ray::gcs::RedisGcsClient *client, - const ActorCheckpointID &checkpoint_id, - const ActorCheckpointData &data) { + [worker, actor_id, reply, send_reply_callback, this]( + ray::gcs::RedisGcsClient *client, const ActorCheckpointID &checkpoint_id, + const ActorCheckpointData &data) { RAY_LOG(DEBUG) << "Checkpoint " << checkpoint_id << " saved for actor " << worker->GetActorId(); - // Save this actor-to-checkpoint mapping, and remove old checkpoints associated - // with this actor. + // Save this actor-to-checkpoint mapping, and remove old checkpoints + // associated with this actor. RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId( JobID::Nil(), actor_id, checkpoint_id)); // Send reply to worker. - flatbuffers::FlatBufferBuilder fbb; - auto reply = ray::protocol::CreatePrepareActorCheckpointReply( - fbb, to_flatbuf(fbb, checkpoint_id)); - fbb.Finish(reply); - worker->Connection()->WriteMessageAsync( - static_cast(protocol::MessageType::PrepareActorCheckpointReply), - fbb.GetSize(), fbb.GetBufferPointer(), [](const ray::Status &status) { - if (!status.ok()) { - RAY_LOG(WARNING) - << "Failed to send PrepareActorCheckpointReply to client"; - } - }); + reply->set_checkpoint_id(checkpoint_id.Binary()); + send_reply_callback(Status::OK(), nullptr, []() { + RAY_LOG(WARNING) << "Failed to send PrepareActorCheckpointReply to client"; + }); })); } -void NodeManager::ProcessNotifyActorResumedFromCheckpoint(const uint8_t *message_data) { - auto message = - flatbuffers::GetRoot(message_data); - ActorID actor_id = from_flatbuf(*message->actor_id()); +void NodeManager::HandleNotifyActorResumedFromCheckpointRequest( + const rpc::NotifyActorResumedFromCheckpointRequest &request, + rpc::NotifyActorResumedFromCheckpointReply *reply, + rpc::SendReplyCallback send_reply_callback) { + ActorID actor_id = ActorID::FromBinary(request.actor_id()); ActorCheckpointID checkpoint_id = - from_flatbuf(*message->checkpoint_id()); + ActorCheckpointID::FromBinary(request.checkpoint_id()); RAY_LOG(DEBUG) << "Actor " << actor_id << " was resumed from checkpoint " << checkpoint_id; checkpoint_id_to_restore_.emplace(actor_id, checkpoint_id); -} - -void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client) { - node_manager_client.ProcessMessages(); + send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, @@ -1229,16 +1150,14 @@ void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, send_reply_callback(Status::OK(), nullptr, nullptr); } -void NodeManager::ProcessSetResourceRequest( - const std::shared_ptr &client, const uint8_t *message_data) { - // Read the SetResource message - auto message = flatbuffers::GetRoot(message_data); - - auto const &resource_name = string_from_flatbuf(*message->resource_name()); - double const &capacity = message->capacity(); +void NodeManager::HandleSetResourceRequest(const rpc::SetResourceRequest &request, + rpc::SetResourceReply *reply, + rpc::SendReplyCallback send_reply_callback) { + auto const &resource_name = request.resource_name(); + double const capacity = request.capacity(); bool is_deletion = capacity <= 0; - ClientID client_id = from_flatbuf(*message->client_id()); + ClientID client_id = ClientID::FromBinary(request.client_id()); // If the python arg was null, set client_id to the local client if (client_id.IsNil()) { @@ -1255,7 +1174,7 @@ void NodeManager::ProcessSetResourceRequest( return; } - // Submit to the client table. This calls the ResourceCreateUpdated or ResourceDeleted + // Submit to the resource table. This calls the ResourceCreateUpdated or ResourceDeleted // callback, which updates cluster_resource_map_. if (is_deletion) { RAY_CHECK_OK(gcs_client_->resource_table().RemoveEntries(JobID::Nil(), client_id, @@ -1268,6 +1187,62 @@ void NodeManager::ProcessSetResourceRequest( RAY_CHECK_OK( gcs_client_->resource_table().Update(JobID::Nil(), client_id, data_map, nullptr)); } + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::HandleNotifyUnblockedRequest( + const rpc::NotifyUnblockedRequest &request, rpc::NotifyUnblockedReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a NotifyUnblockedRequest."; + const TaskID current_task_id = TaskID::FromBinary(request.task_id()); + const WorkerID worker_id = WorkerID::FromBinary(request.worker_id()); + + HandleTaskUnblocked(worker_id, current_task_id); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::HandlePushProfileEventsRequest( + const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a PushProfileEventsRequest."; + const auto &profile_table_data = request.profile_table_data(); + RAY_CHECK_OK(gcs_client_->profile_table().AddProfileEventBatch(profile_table_data)); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::HandleFreeObjectsInStoreRequest( + const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply, + rpc::SendReplyCallback send_reply_callback) { + RAY_LOG(DEBUG) << "Received a FreeObjectsInStoreRequest."; + std::vector object_ids = IdVectorFromProtobuf(request.object_ids()); + object_manager_.FreeObjects(object_ids, request.local_only()); + if (request.delete_creating_tasks()) { + // Clean up their creating tasks from GCS. + std::vector creating_task_ids; + for (const auto &object_id : object_ids) { + creating_task_ids.push_back(object_id.TaskId()); + } + gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); + } + send_reply_callback(Status::OK(), nullptr, nullptr); +} + +void NodeManager::HandleHeartbeatRequest(const rpc::HeartbeatRequest &request, + rpc::HeartbeatReply *reply, + rpc::SendReplyCallback send_reply_callback) { + bool is_worker = request.is_worker(); + const auto worker_id = WorkerID::FromBinary(request.worker_id()); + + std::shared_ptr worker = nullptr; + if (is_worker) { + worker = worker_pool_.GetRegisteredWorker(worker_id); + } else { + worker = worker_pool_.GetRegisteredDriver(worker_id); + } + if (worker) { + worker->ClearHeartbeat(); + } + send_reply_callback(Status::OK(), nullptr, nullptr); } void NodeManager::ScheduleTasks( @@ -1574,7 +1549,8 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag // This is a non-actor task. Queue the task for a placement decision or for dispatch // if the task was forwarded. if (forwarded) { - // Check for local dependencies and enqueue as waiting or ready for dispatch. + // Check for local dependencies and enqueue as waiting or ready for + // dispatch. EnqueuePlaceableTask(task); } else { // (See design_docs/task_states.rst for the state transition diagram.) @@ -1586,10 +1562,10 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag } } -void NodeManager::HandleTaskBlocked(const std::shared_ptr &client, +void NodeManager::HandleTaskBlocked(const WorkerID &worker_id, const std::vector &required_object_ids, const TaskID ¤t_task_id, bool ray_get) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); if (worker) { // The client is a worker. If the worker is not already blocked and the // blocked task matches the one assigned to the worker, then mark the @@ -1616,7 +1592,7 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr } else { // The client is a driver. Drivers do not hold resources, so we simply mark // the task as blocked. - worker = worker_pool_.GetRegisteredDriver(client); + worker = worker_pool_.GetRegisteredDriver(worker_id); } RAY_CHECK(worker); @@ -1638,9 +1614,9 @@ void NodeManager::HandleTaskBlocked(const std::shared_ptr } } -void NodeManager::HandleTaskUnblocked( - const std::shared_ptr &client, const TaskID ¤t_task_id) { - std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); +void NodeManager::HandleTaskUnblocked(const WorkerID &worker_id, + const TaskID ¤t_task_id) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); // TODO(swang): Because the object dependencies are tracked in the task // dependency manager, we could actually remove this message entirely and @@ -1670,9 +1646,9 @@ void NodeManager::HandleTaskUnblocked( cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()].Acquire( cpu_resources); } else { - // In this case, we simply don't reacquire the CPU resources for the worker. - // The worker can keep running and when the task finishes, it will simply - // not have any CPU resources to release. + // In this case, we simply don't reacquire the CPU resources for the + // worker. The worker can keep running and when the task finishes, it will + // simply not have any CPU resources to release. RAY_LOG(WARNING) << "Resources oversubscribed: " << cluster_resource_map_[gcs_client_->client_table().GetLocalClientId()] @@ -1684,7 +1660,7 @@ void NodeManager::HandleTaskUnblocked( } else { // The client is a driver. Drivers do not hold resources, so we simply // mark the driver as unblocked. - worker = worker_pool_.GetRegisteredDriver(client); + worker = worker_pool_.GetRegisteredDriver(worker_id); } // Unsubscribe from any `ray.get` objects that the task was blocked on. Any @@ -1719,6 +1695,24 @@ void NodeManager::EnqueuePlaceableTask(const Task &task) { task_dependency_manager_.TaskPending(task); } +void NodeManager::HandleWorkerAvailable(const WorkerID &worker_id) { + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(worker_id); + RAY_CHECK(worker); + // If the worker was assigned a task, mark it as finished. + if (!worker->GetAssignedTaskId().IsNil()) { + FinishAssignedTask(*worker); + } + + // Return the worker to the idle pool. + worker_pool_.PushWorker(std::move(worker)); + // Local resource availability changed: invoke scheduling policy for local node. + const ClientID &local_client_id = gcs_client_->client_table().GetLocalClientId(); + cluster_resource_map_[local_client_id].SetLoadResources( + local_queues_.GetResourceLoad()); + // Call task dispatch to assign work to the new worker. + DispatchTasks(local_queues_.GetReadyTasksWithResources()); +} + bool NodeManager::AssignTask(const Task &task) { const TaskSpecification &spec = task.GetTaskSpecification(); @@ -1738,12 +1732,12 @@ bool NodeManager::AssignTask(const Task &task) { if (worker == nullptr) { // There are no workers that can execute this task. // We couldn't assign this task, as no worker available. + RAY_LOG(DEBUG) << "No idle worker is found to assign task " << spec.TaskId(); return false; } - RAY_LOG(DEBUG) << "Assigning task " << spec.TaskId() << " to worker with pid " - << worker->Pid(); - flatbuffers::FlatBufferBuilder fbb; + RAY_LOG(DEBUG) << "Assigning task " << spec.TaskId() << " to worker " + << worker->GetWorkerId() << " pid " << worker->Pid(); // Resource accounting: acquire resources for the assigned task. auto acquired_resources = @@ -1760,40 +1754,44 @@ bool NodeManager::AssignTask(const Task &task) { worker->SetTaskResourceIds(acquired_resources); } - auto task_id = spec.TaskId(); - auto finish_assign_task_callback = [this, worker, task_id](Status status) { - if (worker->UsePush()) { - // NOTE: we cannot directly call `FinishAssignTask` here because - // it assumes the task is in SWAP queue, thus we need to delay invoking this - // function after the assigned tasks are moved from READY queue to SWAP queue - // in `DispatchTasks`. - // Another option is to move the tasks to SWAP queue here just before calling - // `FinishAssignTask` so we can save an io_service post, at the - // expense of calling `MoveTask` for each of the assigned tasks. - // TODO(zhijunfu): after all workers are fully migrated to push mode, the - // `post` below and swap queue can be removed. - io_service_.post([this, status, worker, task_id]() { - FinishAssignTask(task_id, *worker, status.ok()); - }); - } else { - FinishAssignTask(task_id, *worker, status.ok()); - } - }; - ResourceIdSet resource_id_set = worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds()); - worker->AssignTask(task, resource_id_set, finish_assign_task_callback); + worker->AssignTask(task, resource_id_set); + // Actor tasks require extra accounting to track the actor's state. + if (spec.IsActorTask()) { + auto actor_entry = actor_registry_.find(spec.ActorId()); + RAY_CHECK(actor_entry != actor_registry_.end()); + // Process any new actor handles that were created since the + // previous task on this handle was executed. The first task + // submitted on a new actor handle will depend on the dummy object + // returned by the previous task, so the dependency will not be + // released until this first task is submitted. + for (auto &new_handle_id : spec.NewActorHandles()) { + const auto prev_actor_task_id = spec.PreviousActorTaskDummyObjectId(); + RAY_CHECK(!prev_actor_task_id.IsNil()); + // Add the new handle and give it a reference to the finished task's + // execution dependency. + actor_entry->second.AddHandle(new_handle_id, prev_actor_task_id); + } + + // TODO(swang): For actors with multiple actor handles, to + // guarantee that tasks are replayed in the same order after a + // failure, we must update the task's execution dependency to be + // the actor's current execution dependency. + } + + // Notify the task dependency manager that we no longer need this task's + // object dependencies. + RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); - // We assigned this task to a worker. - // (Note this means that we sent the task to the worker. The assignment - // might still fail if the worker fails in the meantime, for instance.) return true; } void NodeManager::FinishAssignedTask(Worker &worker) { TaskID task_id = worker.GetAssignedTaskId(); - RAY_LOG(DEBUG) << "Finished task " << task_id; + RAY_LOG(DEBUG) << "Finished task " << task_id << " from worker " << worker.GetWorkerId() + << " with pid " << worker.Pid(); // (See design_docs/task_states.rst for the state transition diagram.) Task task; @@ -2202,8 +2200,9 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, retry_timer->expires_from_now(retry_duration); retry_timer->async_wait( [this, task_id, retry_timer](const boost::system::error_code &error) { - // Timer killing will receive the boost::asio::error::operation_aborted, - // we only handle the timeout event. + // Timer killing will receive the + // boost::asio::error::operation_aborted, we only handle the timeout + // event. RAY_CHECK(!error); RAY_LOG(INFO) << "Resubmitting task " << task_id << " because ForwardTask failed."; @@ -2325,63 +2324,6 @@ void NodeManager::ForwardTask( }); } -void NodeManager::FinishAssignTask(const TaskID &task_id, Worker &worker, bool success) { - // Remove the ASSIGNED task from the SWAP queue. - Task assigned_task; - TaskState state; - if (!local_queues_.RemoveTask(task_id, &assigned_task, &state)) { - return; - } - - RAY_CHECK(state == TaskState::SWAP); - - if (success) { - auto spec = assigned_task.GetTaskSpecification(); - // We successfully assigned the task to the worker. - worker.AssignTaskId(spec.TaskId()); - worker.AssignJobId(spec.JobId()); - // Actor tasks require extra accounting to track the actor's state. - if (spec.IsActorTask()) { - auto actor_entry = actor_registry_.find(spec.ActorId()); - RAY_CHECK(actor_entry != actor_registry_.end()); - // Process any new actor handles that were created since the - // previous task on this handle was executed. The first task - // submitted on a new actor handle will depend on the dummy object - // returned by the previous task, so the dependency will not be - // released until this first task is submitted. - for (auto &new_handle_id : spec.NewActorHandles()) { - const auto prev_actor_task_id = spec.PreviousActorTaskDummyObjectId(); - RAY_CHECK(!prev_actor_task_id.IsNil()); - // Add the new handle and give it a reference to the finished task's - // execution dependency. - actor_entry->second.AddHandle(new_handle_id, prev_actor_task_id); - } - - // TODO(swang): For actors with multiple actor handles, to - // guarantee that tasks are replayed in the same order after a - // failure, we must update the task's execution dependency to be - // the actor's current execution dependency. - } - - // Mark the task as running. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueTasks({assigned_task}, TaskState::RUNNING); - // Notify the task dependency manager that we no longer need this task's - // object dependencies. - RAY_CHECK(task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId())); - } else { - RAY_LOG(WARNING) << "Failed to send task to worker, disconnecting client"; - // We failed to send the task to the worker, so disconnect the worker. - ProcessDisconnectClientMessage(worker.Connection()); - // Queue this task for future assignment. We need to do this since - // DispatchTasks() removed it from the ready queue. The task will be - // assigned to a worker once one becomes available. - // (See design_docs/task_states.rst for the state transition diagram.) - local_queues_.QueueTasks({assigned_task}, TaskState::READY); - DispatchTasks(MakeTasksWithResources({assigned_task})); - } -} - void NodeManager::DumpDebugState() const { std::fstream fs; fs.open(initial_config_.session_dir + "/debug_state.txt", diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 19dc682c3..255058783 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -7,6 +7,8 @@ #include "ray/rpc/client_call.h" #include "ray/rpc/node_manager/node_manager_server.h" #include "ray/rpc/node_manager/node_manager_client.h" +#include "ray/rpc/raylet/raylet_server.h" +#include "ray/object_manager/object_manager.h" #include "ray/common/task/task.h" #include "ray/common/client_connection.h" #include "ray/common/task/task_common.h" @@ -64,7 +66,8 @@ struct NodeManagerConfig { std::string session_dir; }; -class NodeManager : public rpc::NodeManagerServiceHandler { +class NodeManager : public rpc::NodeManagerServiceHandler, + public rpc::RayletServiceHandler { public: /// Create a node manager. /// @@ -75,29 +78,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { std::shared_ptr gcs_client, std::shared_ptr object_directory_); - /// Process a new client connection. - /// - /// \param client The client to process. - /// \return Void. - void ProcessNewClient(LocalClientConnection &client); - - /// Process a message from a client. This method is responsible for - /// explicitly listening for more messages from the client if the client is - /// still alive. - /// - /// \param client The client that sent the message. - /// \param message_type The message type (e.g., a flatbuffer enum). - /// \param message_data A pointer to the message data. - /// \return Void. - void ProcessClientMessage(const std::shared_ptr &client, - int64_t message_type, const uint8_t *message_data); - - /// Handle a new node manager connection. - /// - /// \param node_manager_client The connection to the remote node manager. - /// \return Void. - void ProcessNewNodeManager(TcpClientConnection &node_manager_client); - /// Subscribe to the relevant GCS tables and set up handlers. /// /// \return Status indicating whether this was done successfully or not. @@ -111,9 +91,83 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// Record metrics. void RecordMetrics() const; + public: /// Get the port of the node manager rpc server. int GetServerPort() const { return node_manager_server_.GetPort(); } + /// Implementation of node manager grpc service. + + /// Handle a `ForwardTask` request. + /// + /// \param request The request. + /// \param reply The reply that will be sent to client. + /// \param send_reply_callback Invoke this callback to send reply asynchronously. + void HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + /// Implementation of raylet grpc service handlers, please see definitions + /// in `src/ray/protobuf/raylet.proto` for more details. + void HandleRegisterClientRequest(const rpc::RegisterClientRequest &request, + rpc::RegisterClientReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleSubmitTaskRequest(const rpc::SubmitTaskRequest &request, + rpc::SubmitTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleDisconnectClientRequest(const rpc::DisconnectClientRequest &request, + rpc::DisconnectClientReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleGetTaskRequest(const rpc::GetTaskRequest &request, rpc::GetTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleTaskDoneRequest(const rpc::TaskDoneRequest &request, + rpc::TaskDoneReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleFetchOrReconstructRequest( + const rpc::FetchOrReconstructRequest &request, rpc::FetchOrReconstructReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleNotifyUnblockedRequest(const rpc::NotifyUnblockedRequest &request, + rpc::NotifyUnblockedReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleWaitRequest(const rpc::WaitRequest &request, rpc::WaitReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandlePushErrorRequest(const rpc::PushErrorRequest &request, + rpc::PushErrorReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandlePushProfileEventsRequest( + const rpc::PushProfileEventsRequest &request, rpc::PushProfileEventsReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleFreeObjectsInStoreRequest( + const rpc::FreeObjectsInStoreRequest &request, rpc::FreeObjectsInStoreReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandlePrepareActorCheckpointRequest( + const rpc::PrepareActorCheckpointRequest &request, + rpc::PrepareActorCheckpointReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleNotifyActorResumedFromCheckpointRequest( + const rpc::NotifyActorResumedFromCheckpointRequest &request, + rpc::NotifyActorResumedFromCheckpointReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleSetResourceRequest(const rpc::SetResourceRequest &request, + rpc::SetResourceReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + + void HandleHeartbeatRequest(const rpc::HeartbeatRequest &request, + rpc::HeartbeatReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + private: /// Methods for handling clients. @@ -214,6 +268,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void SubmitTask(const Task &task, const Lineage &uncommitted_lineage, bool forwarded = false); + /// Handle the case that a worker is available. + /// + /// \param id Id of the worker. + /// \return Void. + void HandleWorkerAvailable(const WorkerID &worker_id); /// Assign a task. The task is assumed to not be queued in local_queues_. /// /// \param task The task in question. @@ -322,7 +381,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param ray_get Whether the task is blocked in a `ray.get` call, as /// opposed to a `ray.wait` call. /// \return Void. - void HandleTaskBlocked(const std::shared_ptr &client, + void HandleTaskBlocked(const WorkerID &worker_id, const std::vector &required_object_ids, const TaskID ¤t_task_id, bool ray_get); @@ -335,8 +394,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \param client The client that is executing the unblocked task. /// \param current_task_id The task that is unblocked. /// \return Void. - void HandleTaskUnblocked(const std::shared_ptr &client, - const TaskID ¤t_task_id); + void HandleTaskUnblocked(const WorkerID &worker_id, const TaskID ¤t_task_id); /// Kill a worker. /// @@ -410,11 +468,10 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// client. /// /// \param client The client that sent the message. - /// \param intentional_disconnect Wether the client was intentionally disconnected. + /// \param intentional_disconnect Whether the client was intentionally disconnected. /// \return Void. - void ProcessDisconnectClientMessage( - const std::shared_ptr &client, - bool intentional_disconnect = false); + void ProcessDisconnectClientMessage(const WorkerID &worker_id, + bool intentional_disconnect = false); /// Process client message of SubmitTask /// @@ -482,19 +539,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); - /// Finish assigning a task to a worker. - /// - /// \param task_id Id of the task. - /// \param worker Worker which the task is assigned to. - /// \param success Whether the task is successfully assigned to the worker. - /// \return void. - void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success); - - /// Handle a `ForwardTask` request. - void HandleForwardTask(const rpc::ForwardTaskRequest &request, - rpc::ForwardTaskReply *reply, - rpc::SendReplyCallback send_reply_callback) override; - // GCS client ID for this node. ClientID client_id_; boost::asio::io_service &io_service_; @@ -551,9 +595,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// The RPC server. rpc::GrpcServer node_manager_server_; - /// The RPC service. + /// The node manager RPC service. rpc::NodeManagerGrpcService node_manager_service_; + /// The raylet RPC service. + rpc::RayletGrpcService raylet_service_; + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s /// as well as all `WorkerTaskClient`s. rpc::ClientCallManager client_call_manager_; diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 72d9ab799..0322e64a3 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -7,34 +7,6 @@ #include "ray/common/status.h" -namespace { - -const std::vector GenerateEnumNames(const char *const *enum_names_ptr, - int start_index, int end_index) { - std::vector enum_names; - for (int i = 0; i < start_index; ++i) { - enum_names.push_back("EmptyMessageType"); - } - size_t i = 0; - while (true) { - const char *name = enum_names_ptr[i]; - if (name == nullptr) { - break; - } - enum_names.push_back(name); - i++; - } - RAY_CHECK(static_cast(end_index) == enum_names.size() - 1) - << "Message Type mismatch!"; - return enum_names; -} - -static const std::vector node_manager_message_enum = - GenerateEnumNames(ray::protocol::EnumNamesMessageType(), - static_cast(ray::protocol::MessageType::MIN), - static_cast(ray::protocol::MessageType::MAX)); -} // namespace - namespace ray { namespace raylet { @@ -51,10 +23,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ node_manager_(main_service, node_manager_config, object_manager_, gcs_client_, object_directory_), socket_name_(socket_name), - acceptor_(main_service, boost::asio::local::stream_protocol::endpoint(socket_name)), - socket_(main_service) { - // Start listening for clients. - DoAccept(); + raylet_server_("Raylet", socket_name), + raylet_service_(main_service, node_manager_) { + raylet_server_.RegisterService(raylet_service_); + raylet_server_.Run(); RAY_CHECK_OK(RegisterGcs( node_ip_address, socket_name_, object_manager_config.store_socket_name, @@ -108,31 +80,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, return Status::OK(); } -void Raylet::DoAccept() { - acceptor_.async_accept(socket_, boost::bind(&Raylet::HandleAccept, this, - boost::asio::placeholders::error)); -} - -void Raylet::HandleAccept(const boost::system::error_code &error) { - if (!error) { - // TODO: typedef these handlers. - ClientHandler client_handler = - [this](LocalClientConnection &client) { node_manager_.ProcessNewClient(client); }; - MessageHandler message_handler = - [this](std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessClientMessage(client, message_type, message); - }; - // Accept a new local client and dispatch it to the node manager. - auto new_connection = LocalClientConnection::Create( - client_handler, message_handler, std::move(socket_), "worker", - node_manager_message_enum, - static_cast(protocol::MessageType::DisconnectClient)); - } - // We're ready to accept another client. - DoAccept(); -} - } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 39e226a77..d39362f70 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -22,7 +22,7 @@ class NodeManager; class Raylet { public: - /// Create a node manager server and listen for new clients. + /// Create a raylet server and listen for local clients. /// /// \param main_service The event loop to run the server on. /// \param object_manager_service The asio io_service tied to the object manager. @@ -75,10 +75,12 @@ class Raylet { /// The name of the socket this raylet listens on. std::string socket_name_; - /// An acceptor for new clients. - boost::asio::local::stream_protocol::acceptor acceptor_; - /// The socket to listen on for new clients. - boost::asio::local::stream_protocol::socket socket_; + /// The gRPC server, listens for local raylet client connections through a unix domain + /// socket. + rpc::GrpcServer raylet_server_; + + /// The gRPC service. + rpc::RayletGrpcService raylet_service_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc deleted file mode 100644 index 87a8af4fb..000000000 --- a/src/ray/raylet/raylet_client.cc +++ /dev/null @@ -1,396 +0,0 @@ -#include "raylet_client.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "ray/common/common_protocol.h" -#include "ray/common/ray_config.h" -#include "ray/common/task/task_spec.h" -#include "ray/raylet/format/node_manager_generated.h" -#include "ray/util/logging.h" - -using MessageType = ray::protocol::MessageType; - -// TODO(rkn): The io methods below should be removed. -int connect_ipc_sock(const std::string &socket_pathname) { - struct sockaddr_un socket_address; - int socket_fd; - - socket_fd = socket(AF_UNIX, SOCK_STREAM, 0); - if (socket_fd < 0) { - RAY_LOG(ERROR) << "socket() failed for pathname " << socket_pathname; - return -1; - } - - memset(&socket_address, 0, sizeof(socket_address)); - socket_address.sun_family = AF_UNIX; - if (socket_pathname.length() + 1 > sizeof(socket_address.sun_path)) { - RAY_LOG(ERROR) << "Socket pathname is too long."; - close(socket_fd); - return -1; - } - strncpy(socket_address.sun_path, socket_pathname.c_str(), socket_pathname.length() + 1); - - if (connect(socket_fd, (struct sockaddr *)&socket_address, sizeof(socket_address)) != - 0) { - close(socket_fd); - return -1; - } - return socket_fd; -} - -int read_bytes(int socket_fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - // Termination condition: EOF or read 'length' bytes total. - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - nbytes = read(socket_fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; // Errno will be set. - } else if (0 == nbytes) { - // Encountered early EOF. - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - return 0; -} - -int write_bytes(int socket_fd, uint8_t *cursor, size_t length) { - ssize_t nbytes = 0; - size_t bytesleft = length; - size_t offset = 0; - while (bytesleft > 0) { - // While we haven't written the whole message, write to the file - // descriptor, advance the cursor, and decrease the amount left to write. - nbytes = write(socket_fd, cursor + offset, bytesleft); - if (nbytes < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) { - continue; - } - return -1; // Errno will be set. - } else if (0 == nbytes) { - // Encountered early EOF. - return -1; - } - RAY_CHECK(nbytes > 0); - bytesleft -= nbytes; - offset += nbytes; - } - return 0; -} - -RayletConnection::RayletConnection(const std::string &raylet_socket, int num_retries, - int64_t timeout) { - // Pick the default values if the user did not specify. - if (num_retries < 0) { - num_retries = RayConfig::instance().num_connect_attempts(); - } - if (timeout < 0) { - timeout = RayConfig::instance().connect_timeout_milliseconds(); - } - RAY_CHECK(!raylet_socket.empty()); - conn_ = -1; - for (int num_attempts = 0; num_attempts < num_retries; ++num_attempts) { - conn_ = connect_ipc_sock(raylet_socket); - if (conn_ >= 0) break; - if (num_attempts > 0) { - RAY_LOG(ERROR) << "Retrying to connect to socket for pathname " << raylet_socket - << " (num_attempts = " << num_attempts - << ", num_retries = " << num_retries << ")"; - } - // Sleep for timeout milliseconds. - usleep(timeout * 1000); - } - // If we could not connect to the socket, exit. - if (conn_ == -1) { - RAY_LOG(FATAL) << "Could not connect to socket " << raylet_socket; - } -} - -ray::Status RayletConnection::Disconnect() { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateDisconnectClient(fbb); - fbb.Finish(message); - auto status = WriteMessage(MessageType::IntentionalDisconnectClient, &fbb); - // Don't be too strict for disconnection errors. - // Just create logs and prevent it from crash. - if (!status.ok()) { - RAY_LOG(ERROR) << status.ToString() - << " [RayletClient] Failed to disconnect from raylet."; - } - return ray::Status::OK(); -} - -ray::Status RayletConnection::ReadMessage(MessageType type, - std::unique_ptr &message) { - int64_t cookie; - int64_t type_field; - int64_t length; - int closed = read_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); - if (closed) goto disconnected; - RAY_CHECK(cookie == RayConfig::instance().ray_cookie()); - closed = read_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); - if (closed) goto disconnected; - closed = read_bytes(conn_, (uint8_t *)&length, sizeof(length)); - if (closed) goto disconnected; - message = std::unique_ptr(new uint8_t[length]); - closed = read_bytes(conn_, message.get(), length); - if (closed) { - // Handle the case in which the socket is closed. - message.reset(nullptr); - disconnected: - message = nullptr; - type_field = static_cast(MessageType::DisconnectClient); - length = 0; - } - if (type_field == static_cast(MessageType::DisconnectClient)) { - return ray::Status::IOError("[RayletClient] Raylet connection closed."); - } - if (type_field != static_cast(type)) { - return ray::Status::TypeError( - std::string("[RayletClient] Raylet connection corrupted. ") + - "Expected message type: " + std::to_string(static_cast(type)) + - "; got message type: " + std::to_string(type_field) + - ". Check logs or dmesg for previous errors."); - } - return ray::Status::OK(); -} - -ray::Status RayletConnection::WriteMessage(MessageType type, - flatbuffers::FlatBufferBuilder *fbb) { - std::unique_lock guard(write_mutex_); - int64_t cookie = RayConfig::instance().ray_cookie(); - int64_t length = fbb ? fbb->GetSize() : 0; - uint8_t *bytes = fbb ? fbb->GetBufferPointer() : nullptr; - int64_t type_field = static_cast(type); - auto io_error = ray::Status::IOError("[RayletClient] Connection closed unexpectedly."); - int closed; - closed = write_bytes(conn_, (uint8_t *)&cookie, sizeof(cookie)); - if (closed) return io_error; - closed = write_bytes(conn_, (uint8_t *)&type_field, sizeof(type_field)); - if (closed) return io_error; - closed = write_bytes(conn_, (uint8_t *)&length, sizeof(length)); - if (closed) return io_error; - closed = write_bytes(conn_, bytes, length * sizeof(char)); - if (closed) return io_error; - return ray::Status::OK(); -} - -ray::Status RayletConnection::AtomicRequestReply( - MessageType request_type, MessageType reply_type, - std::unique_ptr &reply_message, flatbuffers::FlatBufferBuilder *fbb) { - std::unique_lock guard(mutex_); - auto status = WriteMessage(request_type, fbb); - if (!status.ok()) return status; - return ReadMessage(reply_type, reply_message); -} - -RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id, - bool is_worker, const JobID &job_id, const Language &language, - int port) - : client_id_(client_id), - is_worker_(is_worker), - job_id_(job_id), - language_(language), - port_(port) { - // For C++14, we could use std::make_unique - conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); - - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, job_id), - language, port); - fbb.Finish(message); - // Register the process ID with the raylet. - // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. - auto status = conn_->WriteMessage(MessageType::RegisterClientRequest, &fbb); - RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet."); -} - -ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateSubmitTaskRequest( - fbb, fbb.CreateString(task_spec.Serialize())); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::SubmitTask, &fbb); -} - -ray::Status RayletClient::GetTask(std::unique_ptr *task_spec) { - std::unique_ptr reply; - // Receive a task from the raylet. This will block until the raylet - // gives this client a task. - auto status = - conn_->AtomicRequestReply(MessageType::GetTask, MessageType::ExecuteTask, reply); - if (!status.ok()) return status; - // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot(reply.get()); - // Set the resource IDs for this task. - resource_ids_.clear(); - for (size_t i = 0; i < reply_message->fractional_resource_ids()->size(); ++i) { - auto const &fractional_resource_ids = - reply_message->fractional_resource_ids()->Get(i); - auto &acquired_resources = - resource_ids_[string_from_flatbuf(*fractional_resource_ids->resource_name())]; - - size_t num_resource_ids = fractional_resource_ids->resource_ids()->size(); - size_t num_resource_fractions = fractional_resource_ids->resource_fractions()->size(); - RAY_CHECK(num_resource_ids == num_resource_fractions); - RAY_CHECK(num_resource_ids > 0); - for (size_t j = 0; j < num_resource_ids; ++j) { - int64_t resource_id = fractional_resource_ids->resource_ids()->Get(j); - double resource_fraction = fractional_resource_ids->resource_fractions()->Get(j); - if (num_resource_ids > 1) { - int64_t whole_fraction = resource_fraction; - RAY_CHECK(whole_fraction == resource_fraction); - } - acquired_resources.push_back(std::make_pair(resource_id, resource_fraction)); - } - } - - // Return the copy of the task spec and pass ownership to the caller. - task_spec->reset( - new ray::TaskSpecification(string_from_flatbuf(*reply_message->task_spec()))); - return ray::Status::OK(); -} - -ray::Status RayletClient::TaskDone() { - return conn_->WriteMessage(MessageType::TaskDone); -} - -ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, - bool fetch_only, - const TaskID ¤t_task_id) { - flatbuffers::FlatBufferBuilder fbb; - auto object_ids_message = to_flatbuf(fbb, object_ids); - auto message = ray::protocol::CreateFetchOrReconstruct( - fbb, object_ids_message, fetch_only, to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - auto status = conn_->WriteMessage(MessageType::FetchOrReconstruct, &fbb); - return status; -} - -ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreateNotifyUnblocked(fbb, to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::NotifyUnblocked, &fbb); -} - -ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, - int64_t timeout_milliseconds, bool wait_local, - const TaskID ¤t_task_id, WaitResultPair *result) { - // Write request. - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateWaitRequest( - fbb, to_flatbuf(fbb, object_ids), num_returns, timeout_milliseconds, wait_local, - to_flatbuf(fbb, current_task_id)); - fbb.Finish(message); - std::unique_ptr reply; - auto status = conn_->AtomicRequestReply(MessageType::WaitRequest, - MessageType::WaitReply, reply, &fbb); - if (!status.ok()) return status; - // Parse the flatbuffer object. - auto reply_message = flatbuffers::GetRoot(reply.get()); - auto found = reply_message->found(); - for (uint i = 0; i < found->size(); i++) { - ObjectID object_id = ObjectID::FromBinary(found->Get(i)->str()); - result->first.push_back(object_id); - } - auto remaining = reply_message->remaining(); - for (uint i = 0; i < remaining->size(); i++) { - ObjectID object_id = ObjectID::FromBinary(remaining->Get(i)->str()); - result->second.push_back(object_id); - } - return ray::Status::OK(); -} - -ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, - const std::string &error_message, double timestamp) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), - fbb.CreateString(error_message), timestamp); - fbb.Finish(message); - - return conn_->WriteMessage(MessageType::PushErrorRequest, &fbb); -} - -ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { - flatbuffers::FlatBufferBuilder fbb; - auto message = fbb.CreateString(profile_events.SerializeAsString()); - fbb.Finish(message); - - auto status = conn_->WriteMessage(MessageType::PushProfileEventsRequest, &fbb); - // Don't be too strict for profile errors. Just create logs and prevent it from crash. - if (!status.ok()) { - RAY_LOG(ERROR) << status.ToString() - << " [RayletClient] Failed to push profile events."; - } - return ray::Status::OK(); -} - -ray::Status RayletClient::FreeObjects(const std::vector &object_ids, - bool local_only, bool delete_creating_tasks) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateFreeObjectsRequest( - fbb, local_only, delete_creating_tasks, to_flatbuf(fbb, object_ids)); - fbb.Finish(message); - - auto status = conn_->WriteMessage(MessageType::FreeObjectsInObjectStoreRequest, &fbb); - return status; -} - -ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, - ActorCheckpointID &checkpoint_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = - ray::protocol::CreatePrepareActorCheckpointRequest(fbb, to_flatbuf(fbb, actor_id)); - fbb.Finish(message); - - std::unique_ptr reply; - auto status = - conn_->AtomicRequestReply(MessageType::PrepareActorCheckpointRequest, - MessageType::PrepareActorCheckpointReply, reply, &fbb); - if (!status.ok()) return status; - auto reply_message = - flatbuffers::GetRoot(reply.get()); - checkpoint_id = ActorCheckpointID::FromBinary(reply_message->checkpoint_id()->str()); - return ray::Status::OK(); -} - -ray::Status RayletClient::NotifyActorResumedFromCheckpoint( - const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateNotifyActorResumedFromCheckpoint( - fbb, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, checkpoint_id)); - fbb.Finish(message); - - return conn_->WriteMessage(MessageType::NotifyActorResumedFromCheckpoint, &fbb); -} - -ray::Status RayletClient::SetResource(const std::string &resource_name, - const double capacity, - const ray::ClientID &client_Id) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateSetResourceRequest( - fbb, fbb.CreateString(resource_name), capacity, to_flatbuf(fbb, client_Id)); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::SetResourceRequest, &fbb); -} diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 397eb81aa..b67a14818 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -5,7 +5,6 @@ #include -#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/reconstruction_policy.h" #include "ray/object_manager/object_directory.h" @@ -400,6 +399,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { auto task_reconstruction_data = std::make_shared(); task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); task_reconstruction_data->set_num_reconstructions(0); + RAY_CHECK_OK( mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 1084d356c..aebcc90b2 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -340,7 +340,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { // Check that we were able to renew the task lease before the previous one // expired. if (now_ms > it->second.expires_at) { - RAY_LOG(WARNING) << "Task lease to renew has already expired by " + RAY_LOG(WARNING) << "Task " << task_id << " lease to renew has already expired by " << (it->second.expires_at - now_ms) << "ms"; } diff --git a/src/ray/raylet/task_dependency_manager.h b/src/ray/raylet/task_dependency_manager.h index 7effa44ed..dc4e5d4ab 100644 --- a/src/ray/raylet/task_dependency_manager.h +++ b/src/ray/raylet/task_dependency_manager.h @@ -3,6 +3,7 @@ // clang-format off #include "ray/common/id.h" +#include "ray/common/ray_config.h" #include "ray/common/task/task.h" #include "ray/object_manager/object_manager.h" #include "ray/raylet/reconstruction_policy.h" diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 820dfdee2..cf92b21dd 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -2,7 +2,6 @@ #include -#include "ray/raylet/format/node_manager_generated.h" #include "ray/raylet/raylet.h" namespace ray { @@ -11,15 +10,14 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - std::shared_ptr connection, rpc::ClientCallManager &client_call_manager) : worker_id_(worker_id), pid_(pid), - language_(language), port_(port), - connection_(connection), - dead_(false), + language_(language), blocked_(false), + num_missed_heartbeats_(0), + is_being_killed_(false), client_call_manager_(client_call_manager) { if (port_ > 0) { rpc_client_ = std::unique_ptr( @@ -27,9 +25,9 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i } } -void Worker::MarkDead() { dead_ = true; } +void Worker::MarkAsBeingKilled() { is_being_killed_ = true; } -bool Worker::IsDead() const { return dead_; } +bool Worker::IsBeingKilled() const { return is_being_killed_; } void Worker::MarkBlocked() { blocked_ = true; } @@ -43,6 +41,8 @@ pid_t Worker::Pid() const { return pid_; } Language Worker::GetLanguage() const { return language_; } +const WorkerID &Worker::GetWorkerId() const { return worker_id_; } + int Worker::Port() const { return port_; } void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } @@ -76,10 +76,6 @@ void Worker::AssignActorId(const ActorID &actor_id) { const ActorID &Worker::GetActorId() const { return actor_id_; } -const std::shared_ptr Worker::Connection() const { - return connection_; -} - const ResourceIdSet &Worker::GetLifetimeResourceIds() const { return lifetime_resource_ids_; } @@ -113,10 +109,16 @@ void Worker::AcquireTaskCpuResources(const ResourceIdSet &cpu_resources) { task_resource_ids_.Release(cpu_resources); } +void Worker::SetGetTaskReplyAndCallback( + rpc::GetTaskReply *reply, const rpc::SendReplyCallback &&send_reply_callback) { + RAY_CHECK(reply_ == nullptr && send_reply_callback_ == nullptr); + reply_ = reply; + send_reply_callback_ = std::move(send_reply_callback); +} + bool Worker::UsePush() const { return rpc_client_ != nullptr; } -void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, - const std::function finish_assign_callback) { +void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set) { const TaskSpecification &spec = task.GetTaskSpecification(); if (rpc_client_ != nullptr) { // Use push mode. @@ -126,29 +128,32 @@ void Worker::AssignTask(const Task &task, const ResourceIdSet &resource_id_set, task.GetTaskSpecification().GetMessage()); request.mutable_task()->mutable_task_execution_spec()->CopyFrom( task.GetTaskExecutionSpec().GetMessage()); - request.set_resource_ids(resource_id_set.Serialize()); - + for (const auto &e : resource_id_set.ToProtobuf()) { + auto resource = request.add_resource_ids(); + *resource = e; + } auto status = rpc_client_->AssignTask( request, [](Status status, const rpc::AssignTaskReply &reply) { // Worker has finished this task. There's nothing to do here // and assigning new task will be done when raylet receives // `TaskDone` message. }); - finish_assign_callback(status); } else { // Use pull mode. This corresponds to existing python/java workers that haven't been // migrated to core worker architecture. - flatbuffers::FlatBufferBuilder fbb; - auto resource_id_set_flatbuf = resource_id_set.ToFlatbuf(fbb); - - auto message = - protocol::CreateGetTaskReply(fbb, fbb.CreateString(spec.Serialize()), - fbb.CreateVector(resource_id_set_flatbuf)); - fbb.Finish(message); - Connection()->WriteMessageAsync( - static_cast(protocol::MessageType::ExecuteTask), fbb.GetSize(), - fbb.GetBufferPointer(), finish_assign_callback); + RAY_CHECK(reply_ != nullptr && send_reply_callback_ != nullptr); + reply_->set_task_spec(task.GetTaskSpecification().Serialize()); + for (const auto &e : resource_id_set.ToProtobuf()) { + auto resource = reply_->add_fractional_resource_ids(); + *resource = e; + } + send_reply_callback_(Status::OK(), nullptr, nullptr); + reply_ = nullptr; + send_reply_callback_ = nullptr; } + // The status will be cleared when the worker dies. + AssignTaskId(spec.TaskId()); + AssignJobId(spec.JobId()); } } // namespace raylet diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index aa86a1224..d46209c9f 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -3,12 +3,15 @@ #include -#include "ray/common/client_connection.h" #include "ray/common/id.h" #include "ray/common/task/scheduling_resources.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" +#include "ray/protobuf/common.pb.h" #include "ray/rpc/worker/worker_client.h" +#include "src/ray/protobuf/gcs.pb.h" +#include "src/ray/protobuf/raylet.pb.h" +#include "src/ray/rpc/server_call.h" namespace ray { @@ -21,12 +24,11 @@ class Worker { public: /// A constructor that initializes a worker object. Worker(const WorkerID &worker_id, pid_t pid, const Language &language, int port, - std::shared_ptr connection, rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. ~Worker() {} - void MarkDead(); - bool IsDead() const; + void MarkAsBeingKilled(); + bool IsBeingKilled() const; void MarkBlocked(); void MarkUnblocked(); bool IsBlocked() const; @@ -35,6 +37,7 @@ class Worker { /// Return the worker's PID. pid_t Pid() const; Language GetLanguage() const; + const WorkerID &GetWorkerId() const; int Port() const; void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; @@ -45,8 +48,6 @@ class Worker { const JobID &GetAssignedJobId() const; void AssignActorId(const ActorID &actor_id); const ActorID &GetActorId() const; - /// Return the worker's connection. - const std::shared_ptr Connection() const; const ResourceIdSet &GetLifetimeResourceIds() const; void SetLifetimeResourceIds(ResourceIdSet &resource_ids); @@ -58,30 +59,34 @@ class Worker { ResourceIdSet ReleaseTaskCpuResources(); void AcquireTaskCpuResources(const ResourceIdSet &cpu_resources); + int TickHeartbeatTimer() { return ++num_missed_heartbeats_; } + void ClearHeartbeat() { num_missed_heartbeats_ = 0; } + + /// When receiving a `GetTask` request from worker, this function should be called to + /// pass the reply and callback of the request to the worker. Later, when we actually + /// assign a task to the worker, the reply will be filled and the callback will be + /// called. + void SetGetTaskReplyAndCallback(rpc::GetTaskReply *reply, + const rpc::SendReplyCallback &&send_reply_callback); + bool UsePush() const; - void AssignTask(const Task &task, const ResourceIdSet &resource_id_set, - const std::function finish_assign_callback); + void AssignTask(const Task &task, const ResourceIdSet &resource_id_set); private: /// The worker's ID. WorkerID worker_id_; /// The worker's PID. pid_t pid_; + /// The worker port. + int port_; /// The language type of this worker. Language language_; - /// Port that this worker listens on. - /// If port <= 0, this indicates that the worker will not listen to a port. - int port_; - /// Connection state of a worker. - std::shared_ptr connection_; /// The worker's currently assigned task. TaskID assigned_task_id_; /// Job ID for the worker's current assigned task. JobID assigned_job_id_; /// The worker's actor ID. If this is nil, then the worker is not an actor. ActorID actor_id_; - /// Whether the worker is dead. - bool dead_; /// Whether the worker is blocked. Workers become blocked in a `ray.get`, if /// they require a data dependency while executing a task. bool blocked_; @@ -92,11 +97,20 @@ class Worker { // of a task. ResourceIdSet task_resource_ids_; std::unordered_set blocked_task_ids_; + /// How many heartbeats have been missed for this worker. + int num_missed_heartbeats_; + /// Indicates we have sent kill signal to the worker if it's true. We cannot treat the + /// worker process as really dead until we lost the heartbeats from the worker. + bool is_being_killed_; /// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all /// workers. rpc::ClientCallManager &client_call_manager_; /// The rpc client to send tasks to this worker. std::unique_ptr rpc_client_; + /// Reply of the `GetTask` request. + rpc::GetTaskReply *reply_ = nullptr; + /// Callback of the `GetTask` request. + rpc::SendReplyCallback send_reply_callback_ = nullptr; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index a02014020..6b5422a3e 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -12,29 +12,6 @@ #include "ray/util/logging.h" #include "ray/util/util.h" -namespace { - -// A helper function to get a worker from a list. -std::shared_ptr GetWorker( - const std::unordered_set> &worker_pool, - const std::shared_ptr &connection) { - for (auto it = worker_pool.begin(); it != worker_pool.end(); it++) { - if ((*it)->Connection() == connection) { - return (*it); - } - } - return nullptr; -} - -// A helper function to remove a worker from a list. Returns true if the worker -// was found and removed. -bool RemoveWorker(std::unordered_set> &worker_pool, - const std::shared_ptr &worker) { - return worker_pool.erase(worker) > 0; -} - -} // namespace - namespace ray { namespace raylet { @@ -73,8 +50,8 @@ WorkerPool::~WorkerPool() { for (const auto &entry : states_by_lang_) { // Kill all registered workers. NOTE(swang): This assumes that the registered // workers were started by the pool. - for (const auto &worker : entry.second.registered_workers) { - pids_to_kill.insert(worker->Pid()); + for (const auto &worker_pair : entry.second.registered_workers) { + pids_to_kill.insert(worker_pair.second->Pid()); } // Kill all the workers that have been started but not registered. for (const auto &starting_worker : entry.second.starting_worker_processes) { @@ -189,12 +166,13 @@ pid_t WorkerPool::StartProcess(const std::vector &worker_command_ar return 0; } -void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { +void WorkerPool::RegisterWorker(const WorkerID &worker_id, + const std::shared_ptr &worker) { const auto pid = worker->Pid(); const auto port = worker->Port(); RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port; auto &state = GetStateForLanguage(worker->GetLanguage()); - state.registered_workers.insert(std::move(worker)); + state.registered_workers.emplace(worker_id, std::move(worker)); auto it = state.starting_worker_processes.find(pid); if (it == state.starting_worker_processes.end()) { @@ -207,29 +185,30 @@ void WorkerPool::RegisterWorker(const std::shared_ptr &worker) { } } -void WorkerPool::RegisterDriver(const std::shared_ptr &driver) { +void WorkerPool::RegisterDriver(const WorkerID &driver_id, + const std::shared_ptr &driver) { RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); auto &state = GetStateForLanguage(driver->GetLanguage()); - state.registered_drivers.insert(std::move(driver)); + state.registered_drivers.emplace(driver_id, std::move(driver)); } -std::shared_ptr WorkerPool::GetRegisteredWorker( - const std::shared_ptr &connection) const { +std::shared_ptr WorkerPool::GetRegisteredWorker(const WorkerID &worker_id) const { for (const auto &entry : states_by_lang_) { - auto worker = GetWorker(entry.second.registered_workers, connection); - if (worker != nullptr) { - return worker; + auto ®istered_workers = entry.second.registered_workers; + auto it = registered_workers.find(worker_id); + if (it != registered_workers.end()) { + return it->second; } } return nullptr; } -std::shared_ptr WorkerPool::GetRegisteredDriver( - const std::shared_ptr &connection) const { +std::shared_ptr WorkerPool::GetRegisteredDriver(const WorkerID &worker_id) const { for (const auto &entry : states_by_lang_) { - auto driver = GetWorker(entry.second.registered_drivers, connection); - if (driver != nullptr) { - return driver; + auto ®istered_drivers = entry.second.registered_drivers; + auto it = registered_drivers.find(worker_id); + if (it != registered_drivers.end()) { + return it->second; } } return nullptr; @@ -313,18 +292,20 @@ std::shared_ptr WorkerPool::PopWorker(const TaskSpecification &task_spec bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { auto &state = GetStateForLanguage(worker->GetLanguage()); - RAY_CHECK(RemoveWorker(state.registered_workers, worker)); + RAY_CHECK(state.registered_workers.erase(worker->GetWorkerId())); stats::CurrentWorker().Record( 0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); - return RemoveWorker(state.idle, worker); + // Indicates that we disconnect a idle worker successfully. + return (state.idle.erase(worker) > 0); } void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { auto &state = GetStateForLanguage(driver->GetLanguage()); - RAY_CHECK(RemoveWorker(state.registered_drivers, driver)); + RAY_CHECK(state.registered_drivers.erase(driver->GetWorkerId())); + stats::CurrentDriver().Record( 0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); @@ -341,7 +322,8 @@ std::vector> WorkerPool::GetWorkersRunningTasksForJob( std::vector> workers; for (const auto &entry : states_by_lang_) { - for (const auto &worker : entry.second.registered_workers) { + for (const auto &worker_pair : entry.second.registered_workers) { + auto &worker = worker_pair.second; if (worker->GetAssignedJobId() == job_id) { workers.push_back(worker); } @@ -396,14 +378,16 @@ std::string WorkerPool::DebugString() const { void WorkerPool::RecordMetrics() const { for (const auto &entry : states_by_lang_) { // Record worker. - for (auto worker : entry.second.registered_workers) { + for (auto worker_pair : entry.second.registered_workers) { + auto &worker = worker_pair.second; stats::CurrentWorker().Record( worker->Pid(), {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->Pid())}}); } // Record driver. - for (auto driver : entry.second.registered_drivers) { + for (auto driver_pair : entry.second.registered_drivers) { + auto &driver = driver_pair.second; stats::CurrentDriver().Record( driver->Pid(), {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->Pid())}}); @@ -411,6 +395,27 @@ void WorkerPool::RecordMetrics() const { } } +void WorkerPool::TickHeartbeatTimer(int max_missed_heartbeats, + std::vector> *dead_workers) { + for (const auto &entry : states_by_lang_) { + // Worker heartbeat. + for (const auto &worker_pair : entry.second.registered_workers) { + auto &worker = worker_pair.second; + if (worker->TickHeartbeatTimer() >= max_missed_heartbeats) { + dead_workers->emplace_back(worker); + } + } + + // Driver heartbeat. + for (const auto &driver_pair : entry.second.registered_drivers) { + auto &driver = driver_pair.second; + if (driver->TickHeartbeatTimer() >= max_missed_heartbeats) { + dead_workers->emplace_back(driver); + } + } + } +} + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 9569cd5c2..fc42f0545 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -51,28 +51,27 @@ class WorkerPool { /// pool after it becomes idle (e.g., requests a work assignment). /// /// \param The Worker to be registered. - void RegisterWorker(const std::shared_ptr &worker); + void RegisterWorker(const WorkerID &worker_id, const std::shared_ptr &worker); /// Register a new driver. + /// Driver is a treated as a special worker, so use WorkerID as key here. /// /// \param The driver to be registered. - void RegisterDriver(const std::shared_ptr &worker); + void RegisterDriver(const WorkerID &driver_id, const std::shared_ptr &worker); /// Get the client connection's registered worker. /// /// \param The client connection owned by a registered worker. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a worker yet. - std::shared_ptr GetRegisteredWorker( - const std::shared_ptr &connection) const; + std::shared_ptr GetRegisteredWorker(const WorkerID &worker_id) const; /// Get the client connection's registered driver. /// /// \param The client connection owned by a registered driver. /// \return The Worker that owns the given client connection. Returns nullptr /// if the client has not registered a driver. - std::shared_ptr GetRegisteredDriver( - const std::shared_ptr &connection) const; + std::shared_ptr GetRegisteredDriver(const WorkerID &driver_id) const; /// Disconnect a registered worker. /// @@ -129,6 +128,16 @@ class WorkerPool { /// Record metrics. void RecordMetrics() const; + /// Tick the heartbeat timer and get the workers that have timed out. + /// A worker which has missed `max_missed_heartbeats` times would be treated as a + /// dead process or the network to it has been down. + /// + /// \param[in] max_missed_heartbeats The maximum number of heartbeats that can be + /// missed before a worker times out. + /// \param[out] dead_workers Workers whose processes have been dead. + void TickHeartbeatTimer(int max_missed_heartbeats, + std::vector> *dead_workers); + protected: /// Asynchronously start a new worker process. Once the worker process has /// registered with an external server, the process should create and @@ -166,9 +175,9 @@ class WorkerPool { std::unordered_map> idle_actor; /// All workers that have registered and are still connected, including both /// idle and executing. - std::unordered_set> registered_workers; + std::unordered_map> registered_workers; /// All drivers that have registered and are still connected. - std::unordered_set> registered_drivers; + std::unordered_map> registered_drivers; /// A map from the pids of starting worker processes /// to the number of their unregistered workers. std::unordered_map starting_worker_processes; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 05cbfaab2..c6c00257f 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -71,19 +71,9 @@ class WorkerPoolTest : public ::testing::Test { std::shared_ptr CreateWorker(pid_t pid, const Language &language = Language::PYTHON) { - std::function client_handler = - [this](LocalClientConnection &client) { HandleNewClient(client); }; - std::function, int64_t, const uint8_t *)> - message_handler = [this](std::shared_ptr client, - int64_t message_type, const uint8_t *message) { - HandleMessage(client, message_type, message); - }; - boost::asio::local::stream_protocol::socket socket(io_service_); - auto client = - LocalClientConnection::Create(client_handler, message_handler, std::move(socket), - "worker", {}, error_message_type_); - return std::shared_ptr(new Worker(WorkerID::FromRandom(), pid, language, -1, - client, client_call_manager_)); + WorkerID worker_id = WorkerID::FromRandom(); + return std::shared_ptr(new Worker( + worker_id, pid, language, /* listening port */ -1, client_call_manager_)); } void SetWorkerCommands(const WorkerCommandMap &worker_commands) { @@ -96,10 +86,6 @@ class WorkerPoolTest : public ::testing::Test { boost::asio::io_service io_service_; int64_t error_message_type_; rpc::ClientCallManager client_call_manager_; - - private: - void HandleNewClient(LocalClientConnection &){}; - void HandleMessage(std::shared_ptr, int64_t, const uint8_t *){}; }; static inline TaskSpecification ExampleTaskSpec( @@ -131,21 +117,23 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { workers.push_back(CreateWorker(pid)); } for (const auto &worker : workers) { + WorkerID worker_id = worker->GetWorkerId(); // Check that there's still a starting worker process // before all workers have been registered ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1); // Check that we cannot lookup the worker before it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); - worker_pool_.RegisterWorker(worker); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr); + worker_pool_.RegisterWorker(worker_id, worker); // Check that we can lookup the worker after it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), worker); } // Check that there's no starting worker process ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0); for (const auto &worker : workers) { + WorkerID worker_id = worker->GetWorkerId(); worker_pool_.DisconnectWorker(worker); // Check that we cannot lookup the worker after it's disconnected. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); + ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker_id), nullptr); } } diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 25f8b3ac7..27935f489 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -126,11 +126,14 @@ class ClientCallManager { explicit ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) { // Start the polling thread. - std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this); - polling_thread.detach(); + polling_thread_ = + std::thread(&ClientCallManager::PollEventsFromCompletionQueue, this); } - ~ClientCallManager() { cq_.Shutdown(); } + ~ClientCallManager() { + cq_.Shutdown(); + polling_thread_.join(); + } /// Create a new `ClientCall` and send request. /// @@ -177,7 +180,7 @@ class ClientCallManager { // Keep reading events from the `CompletionQueue` until it's shutdown. while (cq_.Next(&got_tag, &ok)) { auto tag = reinterpret_cast(got_tag); - if (ok) { + if (ok && !main_service_.stopped()) { // Post the callback to the main event loop. main_service_.post([tag]() { tag->GetCall()->OnReplyReceived(); @@ -195,6 +198,9 @@ class ClientCallManager { /// The gRPC `CompletionQueue` object used to poll events. grpc::CompletionQueue cq_; + + /// Polling thread to check the completion queue. + std::thread polling_thread_; }; } // namespace rpc diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 8e4120869..be7fe90ea 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -6,14 +6,20 @@ namespace ray { namespace rpc { void GrpcServer::Run() { - std::string server_address("0.0.0.0:" + std::to_string(port_)); + std::string server_address; + // Set unix domain socket or tcp address. + if (!unix_socket_path_.empty()) { + server_address = "unix://" + unix_socket_path_; + } else { + server_address = "0.0.0.0:" + std::to_string(port_); + } grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); // Register all the services to this server. if (services_.empty()) { - RAY_LOG(WARNING) << "No service is found when start grpc server " << name_; + RAY_LOG(WARNING) << "No service found when start grpc server " << name_; } for (auto &entry : services_) { builder.RegisterService(&entry.get()); @@ -23,7 +29,11 @@ void GrpcServer::Run() { cq_ = builder.AddCompletionQueue(); // Build and start server. server_ = builder.BuildAndStart(); - RAY_LOG(INFO) << name_ << " server started, listening on port " << port_ << "."; + if (unix_socket_path_.empty()) { + // For a TCP-based server, the actual port is decided after `AddListeningPort`. + server_address = "0.0.0.0:" + std::to_string(port_); + } + RAY_LOG(INFO) << name_ << " server started, listening on " << server_address; // Create calls for all the server call factories. for (auto &entry : server_call_factories_and_concurrencies_) { diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h index 13fd5c02b..8f4c975df 100644 --- a/src/ray/rpc/grpc_server.h +++ b/src/ray/rpc/grpc_server.h @@ -26,15 +26,22 @@ class GrpcService; /// which kinds of requests this server should accept. class GrpcServer { public: - /// Constructor. + /// Construct a gRPC server that listens on a TCP port. /// /// \param[in] name Name of this server, used for logging and debugging purpose. /// \param[in] port The port to bind this server to. If it's 0, a random available port /// will be chosen. - /// \param[in] main_service The main event loop, to which service handler functions - /// will be posted. GrpcServer(std::string name, const uint32_t port) - : name_(std::move(name)), port_(port), is_closed_(true) {} + : name_(std::move(name)), port_(port), is_closed_(false) {} + + /// Construct a gRPC server that listens on unix domain socket. + /// + /// \param[in] name Name of this server, used for logging and debugging purpose. + /// \param[in] unix_socket_path Unix domain socket full path. + GrpcServer(std::string name, const std::string &unix_socket_path) + : GrpcServer(std::move(name), 0) { + unix_socket_path_ = unix_socket_path; + } /// Destruct this gRPC server. ~GrpcServer() { Shutdown(); } @@ -73,6 +80,10 @@ class GrpcServer { const std::string name_; /// Port of this server. int port_; + /// Indicates whether this server has been closed. + bool is_closed_; + /// Unix domain socket path. + std::string unix_socket_path_; /// The `grpc::Service` objects which should be registered to `ServerBuilder`. std::vector> services_; /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that @@ -83,10 +94,8 @@ class GrpcServer { std::unique_ptr cq_; /// The `Server` object. std::unique_ptr server_; - /// The polling thread used to check the completion queue + /// The polling thread used to check the completion queue. std::thread polling_thread_; - /// Flag indicates whether this server has closed - bool is_closed_; }; /// Base class that represents an abstract gRPC service. diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 005c75db4..5d7c915bf 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -5,6 +5,7 @@ #include +#include "ray/common/grpc_util.h" #include "ray/common/status.h" #include "ray/rpc/client_call.h" #include "ray/util/logging.h" diff --git a/src/ray/rpc/object_manager/object_manager_client.h b/src/ray/rpc/object_manager/object_manager_client.h index f37a081e6..9c9921503 100644 --- a/src/ray/rpc/object_manager/object_manager_client.h +++ b/src/ray/rpc/object_manager/object_manager_client.h @@ -5,6 +5,7 @@ #include +#include "ray/common/grpc_util.h" #include "ray/common/status.h" #include "ray/util/logging.h" #include "src/ray/protobuf/object_manager.grpc.pb.h" diff --git a/src/ray/rpc/raylet/raylet_client.cc b/src/ray/rpc/raylet/raylet_client.cc new file mode 100644 index 000000000..c6cfc2f05 --- /dev/null +++ b/src/ray/rpc/raylet/raylet_client.cc @@ -0,0 +1,411 @@ + +#include "src/ray/rpc/raylet/raylet_client.h" + +namespace ray { +namespace rpc { + +#define RETURN_IF_DISCONNECTED(connected) \ + do { \ + if (!(connected)) { \ + return Status::Invalid("Raylet connection is closed."); \ + } \ + } while (0) + +RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, + bool is_worker, const JobID &job_id, const Language &language, + int port) + : worker_id_(worker_id), + is_worker_(is_worker), + job_id_(job_id), + language_(language), + port_(port), + main_service_(), + work_(main_service_), + client_call_manager_(main_service_), + heartbeat_timer_(main_service_), + is_connected_(false) { + std::shared_ptr channel = + grpc::CreateChannel("unix://" + raylet_socket, grpc::InsecureChannelCredentials()); + stub_ = RayletService::NewStub(channel); + + rpc_thread_ = std::thread([this]() { main_service_.run(); }); + RAY_LOG(DEBUG) << "Connecting to unix socket: " + << "unix://" + raylet_socket + << ", is worker: " << (is_worker_ ? "true" : "false") + << ", worker id: " << worker_id; + // Try to register client `num_raylet_client_retry_times` times. + TryRegisterClient(RayConfig::instance().num_raylet_client_retry_times()); +} + +void RayletClient::TryRegisterClient(int retry_times) { + // We should block here until register succeeds. + for (int i = 0; i < retry_times; i++) { + auto st = RegisterClient(); + if (st.ok()) { + is_connected_ = true; + Heartbeat(); + return; + } + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + RAY_LOG(FATAL) << "Failed to register to raylet server, worker id: " << worker_id_ + << ", pid: " << static_cast(getpid()) + << ", is worker: " << is_worker_; +} + +RayletClient::~RayletClient() { + is_connected_ = false; + main_service_.stop(); + rpc_thread_.join(); +} + +ray::Status RayletClient::Disconnect() { + DisconnectClientRequest disconnect_client_request; + disconnect_client_request.set_worker_id(worker_id_.Binary()); + + DisconnectClientReply reply; + grpc::ClientContext context; + auto status = stub_->DisconnectClient(&context, disconnect_client_request, &reply); + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to disconnect from raylet, msg: " << status.error_message(); + } + return GrpcStatusToRayStatus(status); +} + +ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { + RETURN_IF_DISCONNECTED(is_connected_); + SubmitTaskRequest submit_task_request; + submit_task_request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + + grpc::ClientContext context; + SubmitTaskReply reply; + auto status = stub_->SubmitTask(&context, submit_task_request, &reply); + return GrpcStatusToRayStatus(status); +} + +ray::Status RayletClient::GetTask(std::unique_ptr *task_spec) { + RETURN_IF_DISCONNECTED(is_connected_); + GetTaskRequest get_task_request; + get_task_request.set_worker_id(worker_id_.Binary()); + + grpc::ClientContext context; + GetTaskReply reply; + // The actual RPC. + auto status = stub_->GetTask(&context, get_task_request, &reply); + + if (status.ok()) { + resource_ids_.clear(); + // Parse resources that would be used by this assigned task. + for (size_t i = 0; i < reply.fractional_resource_ids().size(); ++i) { + auto const &fractional_resource_ids = reply.fractional_resource_ids()[i]; + auto &acquired_resources = resource_ids_[fractional_resource_ids.resource_name()]; + + // Each resource includes a series of resource IDs (e.g., GPU 0) and corresponding + // amount for that resource ID. If the resource amount is fractional, then there + // should only be one resource ID. + size_t num_resource_ids = fractional_resource_ids.resource_ids().size(); + size_t num_resource_fractions = fractional_resource_ids.resource_fractions().size(); + RAY_CHECK(num_resource_ids == num_resource_fractions); + RAY_CHECK(num_resource_ids > 0); + for (size_t j = 0; j < num_resource_ids; ++j) { + int64_t resource_id = fractional_resource_ids.resource_ids()[j]; + double resource_fraction = fractional_resource_ids.resource_fractions()[j]; + if (num_resource_ids > 1) { + int64_t whole_fraction = resource_fraction; + RAY_CHECK(whole_fraction == resource_fraction); + } + acquired_resources.emplace_back(resource_id, resource_fraction); + } + } + task_spec->reset(new ray::TaskSpecification(reply.task_spec())); + } else { + *task_spec = nullptr; + RAY_LOG(INFO) << "Failed to get task, msg: " << status.error_message(); + } + return GrpcStatusToRayStatus(status); +} + +ray::Status RayletClient::TaskDone() { + RETURN_IF_DISCONNECTED(is_connected_); + TaskDoneRequest task_done_request; + task_done_request.set_worker_id(worker_id_.Binary()); + + auto callback = [this](const Status &status, const TaskDoneReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send TaskDoneRequest, msg: " << status.message(); + } + }; + + auto call = + client_call_manager_.CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncTaskDone, task_done_request, + callback); + return call->GetStatus(); +} + +ray::Status RayletClient::FetchOrReconstruct(const std::vector &object_ids, + bool fetch_only, + const TaskID ¤t_task_id) { + RETURN_IF_DISCONNECTED(is_connected_); + FetchOrReconstructRequest fetch_or_reconstruct_request; + fetch_or_reconstruct_request.set_fetch_only(fetch_only); + fetch_or_reconstruct_request.set_task_id(current_task_id.Binary()); + fetch_or_reconstruct_request.set_worker_id(worker_id_.Binary()); + IdVectorToProtobuf( + object_ids, fetch_or_reconstruct_request, + &FetchOrReconstructRequest::add_object_ids); + + // Callback to deal with reply. + auto callback = [this](const Status &status, const FetchOrReconstructReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send FetchOrReconstructRequest, msg: " + << status.message(); + } + }; + + auto call = + client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncFetchOrReconstruct, + fetch_or_reconstruct_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::NotifyUnblocked(const TaskID ¤t_task_id) { + RETURN_IF_DISCONNECTED(is_connected_); + NotifyUnblockedRequest notify_unblocked_request; + notify_unblocked_request.set_worker_id(worker_id_.Binary()); + notify_unblocked_request.set_task_id(current_task_id.Binary()); + + auto callback = [this](const Status &status, const NotifyUnblockedReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send NotifyUnblockedRequest, msg: " << status.message(); + } + }; + + auto call = + client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncNotifyUnblocked, + notify_unblocked_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::Wait(const std::vector &object_ids, int num_returns, + int64_t timeout_milliseconds, bool wait_local, + const TaskID ¤t_task_id, WaitResultPair *result) { + RETURN_IF_DISCONNECTED(is_connected_); + WaitRequest wait_request; + wait_request.set_worker_id(worker_id_.Binary()); + wait_request.set_timeout(timeout_milliseconds); + wait_request.set_wait_local(wait_local); + wait_request.set_task_id(current_task_id.Binary()); + wait_request.set_num_ready_objects(num_returns); + IdVectorToProtobuf(object_ids, wait_request, + &WaitRequest::add_object_ids); + + grpc::ClientContext context; + WaitReply reply; + auto status = stub_->Wait(&context, wait_request, &reply); + + if (status.ok()) { + result->first = IdVectorFromProtobuf(reply.found()); + result->second = IdVectorFromProtobuf(reply.remaining()); + } else { + RAY_LOG(INFO) << "Failed to send WaitRequest, msg: " << status.error_message(); + } + + return GrpcStatusToRayStatus(status); +} + +ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, + const std::string &error_message, double timestamp) { + RETURN_IF_DISCONNECTED(is_connected_); + PushErrorRequest push_error_request; + push_error_request.set_job_id(job_id.Binary()); + push_error_request.set_type(type); + push_error_request.set_error_message(error_message); + push_error_request.set_timestamp(timestamp); + + auto callback = [this](const Status &status, const PushErrorReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send PushErrorRequest, msg: " << status.message(); + } + }; + + auto call = + client_call_manager_.CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncPushError, push_error_request, + callback); + return call->GetStatus(); +} + +ray::Status RayletClient::PushProfileEvents(const ProfileTableData &profile_events) { + RETURN_IF_DISCONNECTED(is_connected_); + PushProfileEventsRequest push_profile_events_request; + push_profile_events_request.mutable_profile_table_data()->CopyFrom(profile_events); + + auto callback = [this](const Status &status, const PushProfileEventsReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send PushProfileEventsRequest, msg: " + << status.message(); + } + }; + + auto call = + client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncPushProfileEvents, + push_profile_events_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::FreeObjects(const std::vector &object_ids, + bool local_only, bool delete_creating_tasks) { + RETURN_IF_DISCONNECTED(is_connected_); + FreeObjectsInStoreRequest free_objects_request; + free_objects_request.set_local_only(local_only); + free_objects_request.set_delete_creating_tasks(delete_creating_tasks); + IdVectorToProtobuf( + object_ids, free_objects_request, &FreeObjectsInStoreRequest::add_object_ids); + + auto callback = [this](const Status &status, const FreeObjectsInStoreReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Failed to send FreeObjectsInStoreRequest, msg: " + << status.message(); + } + }; + + auto call = + client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncFreeObjectsInStore, + free_objects_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::PrepareActorCheckpoint(const ActorID &actor_id, + ActorCheckpointID &checkpoint_id) { + RETURN_IF_DISCONNECTED(is_connected_); + PrepareActorCheckpointRequest prepare_actor_checkpoint_request; + prepare_actor_checkpoint_request.set_actor_id(actor_id.Binary()); + prepare_actor_checkpoint_request.set_worker_id(worker_id_.Binary()); + + grpc::ClientContext context; + PrepareActorCheckpointReply reply; + auto status = + stub_->PrepareActorCheckpoint(&context, prepare_actor_checkpoint_request, &reply); + + if (status.ok()) { + checkpoint_id = ActorCheckpointID::FromBinary(reply.checkpoint_id()); + } else { + RAY_LOG(INFO) << "Failed to send PrepareActorCheckpointRequest, msg: " + << status.error_message(); + } + + return GrpcStatusToRayStatus(status); +} + +ray::Status RayletClient::NotifyActorResumedFromCheckpoint( + const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { + RETURN_IF_DISCONNECTED(is_connected_); + NotifyActorResumedFromCheckpointRequest notify_actor_resumed_from_checkpoint_request; + notify_actor_resumed_from_checkpoint_request.set_actor_id(actor_id.Binary()); + notify_actor_resumed_from_checkpoint_request.set_checkpoint_id(checkpoint_id.Binary()); + + auto callback = [this](const Status &status, + const NotifyActorResumedFromCheckpointReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "NotifyActorResumedFromCheckpoint failed, msg: " + << status.message(); + } + }; + + auto call = + client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncNotifyActorResumedFromCheckpoint, + notify_actor_resumed_from_checkpoint_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::SetResource(const std::string &resource_name, + const double capacity, + const ray::ClientID &client_id) { + RETURN_IF_DISCONNECTED(is_connected_); + SetResourceRequest set_resource_request; + set_resource_request.set_resource_name(resource_name); + set_resource_request.set_capacity(capacity); + set_resource_request.set_client_id(client_id.Binary()); + + auto callback = [this](const Status &status, const SetResourceReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "SetResource failed, msg: " << status.message(); + } + }; + + auto call = client_call_manager_ + .CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncSetResource, + set_resource_request, callback); + return call->GetStatus(); +} + +ray::Status RayletClient::RegisterClient() { + RegisterClientRequest register_client_request; + register_client_request.set_is_worker(is_worker_); + register_client_request.set_worker_id(worker_id_.Binary()); + register_client_request.set_worker_pid(getpid()); + register_client_request.set_job_id(job_id_.Binary()); + register_client_request.set_language(language_); + register_client_request.set_port(port_); + + grpc::ClientContext context; + RegisterClientReply reply; + auto status = stub_->RegisterClient(&context, register_client_request, &reply); + + if (!status.ok()) { + RAY_LOG(DEBUG) << "Failed to register client, msg: " << status.error_message(); + } + + return GrpcStatusToRayStatus(status); +} + +void RayletClient::Heartbeat() { + if (!is_connected_) { + return; + } + HeartbeatRequest heartbeat_request; + heartbeat_request.set_is_worker(is_worker_); + heartbeat_request.set_worker_id(worker_id_.Binary()); + + auto callback = [this](const Status &status, const HeartbeatReply &reply) { + if (!status.ok() && is_connected_) { + is_connected_ = false; + RAY_LOG(INFO) << "Heartbeat failed, msg: " << status.message(); + } + }; + auto call = + client_call_manager_.CreateCall( + *stub_, &RayletService::Stub::PrepareAsyncHeartbeat, heartbeat_request, + callback); + + heartbeat_timer_.expires_from_now(boost::posix_time::milliseconds( + RayConfig::instance().heartbeat_timeout_milliseconds())); + heartbeat_timer_.async_wait([this](const boost::system::error_code &error) { + RAY_CHECK(!error); + Heartbeat(); + }); +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/raylet/raylet_client.h b/src/ray/rpc/raylet/raylet_client.h similarity index 68% rename from src/ray/raylet/raylet_client.h rename to src/ray/rpc/raylet/raylet_client.h index c27b325df..6960b0f8e 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/rpc/raylet/raylet_client.h @@ -1,84 +1,66 @@ -#ifndef RAYLET_CLIENT_H -#define RAYLET_CLIENT_H +#ifndef RAY_RPC_RAYLET_CLIENT_H +#define RAY_RPC_RAYLET_CLIENT_H +#include #include #include +#include #include +#include #include #include +#include + #include "ray/common/status.h" #include "ray/common/task/task_spec.h" +#include "src/ray/common/status.h" +#include "src/ray/protobuf/raylet.grpc.pb.h" +#include "src/ray/protobuf/raylet.pb.h" +#include "src/ray/rpc/client_call.h" using ray::ActorCheckpointID; using ray::ActorID; -using ray::ClientID; using ray::JobID; using ray::ObjectID; using ray::TaskID; -using ray::UniqueID; +using ray::WorkerID; using ray::Language; using ray::rpc::ProfileTableData; - -using MessageType = ray::protocol::MessageType; -using ResourceMappingType = - std::unordered_map>>; using WaitResultPair = std::pair, std::vector>; -class RayletConnection { - public: - /// Connect to the raylet. - /// - /// \param raylet_socket The name of the socket to use to connect to the raylet. - /// \param worker_id A unique ID to represent the worker. - /// \param is_worker Whether this client is a worker. If it is a worker, an - /// additional message will be sent to register as one. - /// \param job_id The ID of the driver. This is non-nil if the client is a - /// driver. - /// \return The connection information. - RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout); +namespace ray { +namespace rpc { - ~RayletConnection() { close(conn_); } - /// Notify the raylet that this client is disconnecting gracefully. This - /// is used by actors to exit gracefully so that the raylet doesn't - /// propagate an error message to the driver. - /// - /// \return ray::Status. - ray::Status Disconnect(); - ray::Status ReadMessage(MessageType type, std::unique_ptr &message); - ray::Status WriteMessage(MessageType type, - flatbuffers::FlatBufferBuilder *fbb = nullptr); - ray::Status AtomicRequestReply(MessageType request_type, MessageType reply_type, - std::unique_ptr &reply_message, - flatbuffers::FlatBufferBuilder *fbb = nullptr); - - private: - /// File descriptor of the Unix domain socket that connects to raylet. - int conn_; - /// A mutex to protect stateful operations of the raylet client. - std::mutex mutex_; - /// A mutex to protect write operations of the raylet client. - std::mutex write_mutex_; -}; +using ResourceMappingType = + std::unordered_map>>; +/// Client used for communicating with the raylet. class RayletClient { public: - /// Connect to the raylet. + /// Constructor for the raylet client. + /// TODO(jzh): At present, client call manager and reply handler service are generated + /// in raylet client. Instead, we should add parameters to the constructor and pass them + /// in. Change them as input parameters once we changed the worker into a server. /// - /// \param raylet_socket The name of the socket to use to connect to the raylet. - /// \param worker_id A unique ID to represent the worker. - /// \param is_worker Whether this client is a worker. If it is a worker, an - /// additional message will be sent to register as one. - /// \param job_id The ID of the driver. This is non-nil if the client is a driver. - /// \return The connection information. - RayletClient(const std::string &raylet_socket, const ClientID &client_id, + /// \param[in] raylet_socket Unix domain socket of the raylet server. + /// \param[in] worker_id The worker id. + /// \param[in] is_worker Indicates whether a worker or a driver. + /// \param[in] job_id The job id that this raylet client belongs to. + /// \param[in] language The language type, python or java. + /// \param[in] port The listening port of the worker server, -1 means that the worker + /// does not have a server. + RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, int port = -1); - ray::Status Disconnect() { return conn_->Disconnect(); }; + ~RayletClient(); - /// Submit a task using the raylet code path. + /// Send disconnect request to local raylet. + ray::Status Disconnect(); + + /// Submit a task to the local raylet. /// /// \param The task specification. /// \return ray::Status. @@ -178,7 +160,7 @@ class RayletClient { Language GetLanguage() const { return language_; } - ClientID GetClientID() const { return client_id_; } + WorkerID GetWorkerId() const { return worker_id_; } JobID GetJobID() const { return job_id_; } @@ -187,7 +169,21 @@ class RayletClient { const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } private: - const ClientID client_id_; + /// Try to register client in raylet, we would retry serveral time to + /// reconnect if failed. We need this because raylet client may start before raylet + /// server. + /// + /// \param times Number of times to retry. + void TryRegisterClient(int retry_times); + + ray::Status RegisterClient(); + + /// Send heartbeat requests to the raylet server. + void Heartbeat(); + /// Id of the worker to which this raylet client belongs. + const WorkerID worker_id_; + /// Indicates whether this worker is a driver worker. + /// Driver is treated as a special worker. const bool is_worker_; const JobID job_id_; const Language language_; @@ -196,8 +192,30 @@ class RayletClient { /// for this worker. Each pair consists of the resource ID and the fraction /// of that resource allocated for this worker. ResourceMappingType resource_ids_; - /// The connection to the raylet server. - std::unique_ptr conn_; + + /// The gRPC-generated stub. + std::unique_ptr stub_; + + /// Service for handling reply. + boost::asio::io_service main_service_; + + /// Asio work for main service. + boost::asio::io_service::work work_; + + /// The `ClientCallManager` used for managing requests. + ClientCallManager client_call_manager_; + + /// The thread used to handle reply. + std::thread rpc_thread_; + + /// Heartbeat timer. + boost::asio::deadline_timer heartbeat_timer_; + + /// Indicates whether the connection has been closed. + bool is_connected_; }; -#endif +} // namespace rpc +} // namespace ray + +#endif // RAY_RPC_RAYLET_CLIENT_H diff --git a/src/ray/rpc/raylet/raylet_server.h b/src/ray/rpc/raylet/raylet_server.h new file mode 100644 index 000000000..608cd26ff --- /dev/null +++ b/src/ray/rpc/raylet/raylet_server.h @@ -0,0 +1,258 @@ +#ifndef RAY_RPC_RAYLET_SERVER_H +#define RAY_RPC_RAYLET_SERVER_H + +#include "src/ray/rpc/grpc_server.h" +#include "src/ray/rpc/server_call.h" + +#include "src/ray/protobuf/raylet.grpc.pb.h" +#include "src/ray/protobuf/raylet.pb.h" + +namespace ray { +namespace rpc { + +/// Implementations of the `RayletService`, check interface in +/// `src/ray/protobuf/raylet.proto`. +class RayletServiceHandler { + public: + virtual void HandleRegisterClientRequest(const RegisterClientRequest &request, + RegisterClientReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `SubmitTask` request. + /// The implementation can handle this request asynchronously. When handling is done, + /// the `send_reply_callback` should be called. + /// + /// \param[in] request The request message. + /// \param[out] reply The reply message. + /// \param[in] send_reply_callback The callback to be called when the request is done. + virtual void HandleSubmitTaskRequest(const SubmitTaskRequest &request, + SubmitTaskReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `DisconnectClient` request. + virtual void HandleDisconnectClientRequest(const DisconnectClientRequest &request, + DisconnectClientReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `GetTask` request. + virtual void HandleGetTaskRequest(const GetTaskRequest &request, GetTaskReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `TaskDone` request. + virtual void HandleTaskDoneRequest(const TaskDoneRequest &request, TaskDoneReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `HandleFetchOrReconstruct` request. + virtual void HandleFetchOrReconstructRequest(const FetchOrReconstructRequest &request, + FetchOrReconstructReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `HandleNotifyUnblocked` request. + virtual void HandleNotifyUnblockedRequest(const NotifyUnblockedRequest &request, + NotifyUnblockedReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `Wait` request. + virtual void HandleWaitRequest(const WaitRequest &request, WaitReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `PushError` request. + virtual void HandlePushErrorRequest(const PushErrorRequest &request, + PushErrorReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `PushProfileEvents` request. + virtual void HandlePushProfileEventsRequest(const PushProfileEventsRequest &request, + PushProfileEventsReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `FreeObjectsInStoreInObjectStore` request. + virtual void HandleFreeObjectsInStoreRequest(const FreeObjectsInStoreRequest &request, + FreeObjectsInStoreReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `PrepareActorCheckpoint` request. + virtual void HandlePrepareActorCheckpointRequest( + const PrepareActorCheckpointRequest &request, PrepareActorCheckpointReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `NotifyActorResumedFromCheckpoint` request. + virtual void HandleNotifyActorResumedFromCheckpointRequest( + const NotifyActorResumedFromCheckpointRequest &request, + NotifyActorResumedFromCheckpointReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `SetResource` request. + virtual void HandleSetResourceRequest(const SetResourceRequest &request, + SetResourceReply *reply, + SendReplyCallback send_reply_callback) = 0; + /// Handle a `SetResourceReply` request. + virtual void HandleHeartbeatRequest(const HeartbeatRequest &request, + HeartbeatReply *reply, + SendReplyCallback send_reply_callback) = 0; +}; + +/// The `GrpcService` for `RayletGrpcService`. +class RayletGrpcService : public GrpcService { + public: + /// Construct a `RayletGrpcService`. + /// + /// \param[in] io_service Service used to handle incoming requests + /// \param[in] handler The service handler that actually handle the requests. + RayletGrpcService(boost::asio::io_service &io_service, + RayletServiceHandler &service_handler) + : GrpcService(io_service), service_handler_(service_handler){}; + + protected: + grpc::Service &GetGrpcService() override { return service_; } + + void InitServerCallFactories( + const std::unique_ptr &cq, + std::vector, int>> + *server_call_factories_and_concurrencies) override { + // Initialize the factory for `RegisterClient` requests. + std::unique_ptr register_client_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestRegisterClient, + service_handler_, &RayletServiceHandler::HandleRegisterClientRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(register_client_call_factory), 10); + + // Initialize the factory for `SubmitTask` requests. + std::unique_ptr submit_task_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestSubmitTask, service_handler_, + &RayletServiceHandler::HandleSubmitTaskRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(submit_task_call_factory), 20); + + // Initialize the factory for `DisconnectClient` requests. + std::unique_ptr disconnect_client_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestDisconnectClient, + service_handler_, &RayletServiceHandler::HandleDisconnectClientRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(disconnect_client_call_factory), 10); + + // Initialize the factory for `GetTask` requests. + std::unique_ptr get_task_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestGetTask, service_handler_, + &RayletServiceHandler::HandleGetTaskRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(get_task_call_factory), 20); + + // Initialize the factory for `TaskDone` requests. + std::unique_ptr task_done_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestTaskDone, service_handler_, + &RayletServiceHandler::HandleTaskDoneRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(task_done_call_factory), 20); + + // Initialize the factory for `FetchOrReconstruct` requests. + std::unique_ptr fetch_or_reconstruct_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestFetchOrReconstruct, + service_handler_, &RayletServiceHandler::HandleFetchOrReconstructRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(fetch_or_reconstruct_call_factory), 10); + + // Initialize the factory for `NotifyUnblocked` requests. + std::unique_ptr notify_unblocked_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestNotifyUnblocked, + service_handler_, &RayletServiceHandler::HandleNotifyUnblockedRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(notify_unblocked_call_factory), 10); + + // Initialize the factory for `Wait` requests. + std::unique_ptr wait_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestWait, service_handler_, + &RayletServiceHandler::HandleWaitRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back(std::move(wait_call_factory), + 20); + + // Initialize the factory for `PushError` requests. + std::unique_ptr push_error_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestPushError, service_handler_, + &RayletServiceHandler::HandlePushErrorRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(push_error_call_factory), 10); + + // Initialize the factory for `PushProfileEvents` requests. + std::unique_ptr push_profile_events_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestPushProfileEvents, + service_handler_, &RayletServiceHandler::HandlePushProfileEventsRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(push_profile_events_call_factory), 10); + + // Initialize the factory for `FreeObjectsInStore` requests. + std::unique_ptr free_objects_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestFreeObjectsInStore, + service_handler_, &RayletServiceHandler::HandleFreeObjectsInStoreRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(free_objects_call_factory), 10); + + // Initialize the factory for `PrepareActorCheckpoint` requests. + std::unique_ptr prepare_actor_checkpoint_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestPrepareActorCheckpoint, + service_handler_, &RayletServiceHandler::HandlePrepareActorCheckpointRequest, + cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(prepare_actor_checkpoint_call_factory), 10); + + // Initialize the factory for `NotifyActorResumedFromCheckpoint` requests. + std::unique_ptr notify_actor_resumed_from_checkpoint_call_factory( + new ServerCallFactoryImpl( + service_, + &RayletService::AsyncService::RequestNotifyActorResumedFromCheckpoint, + service_handler_, + &RayletServiceHandler::HandleNotifyActorResumedFromCheckpointRequest, cq, + main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(notify_actor_resumed_from_checkpoint_call_factory), 10); + + // Initialize the factory for `SetResource` requests. + std::unique_ptr set_resource_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestSetResource, service_handler_, + &RayletServiceHandler::HandleSetResourceRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(set_resource_call_factory), 10); + + // Initialize the factory for `Heartbeat` requests. + std::unique_ptr heartbeat_call_factory( + new ServerCallFactoryImpl( + service_, &RayletService::AsyncService::RequestHeartbeat, service_handler_, + &RayletServiceHandler::HandleHeartbeatRequest, cq, main_service_)); + server_call_factories_and_concurrencies->emplace_back( + std::move(heartbeat_call_factory), 10); + } + + private: + /// The grpc async service object. + RayletService::AsyncService service_; + /// The service handler that actually handle the requests. + RayletServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h index 33fbbce39..750ecc8f6 100644 --- a/src/ray/rpc/server_call.h +++ b/src/ray/rpc/server_call.h @@ -2,6 +2,7 @@ #define RAY_RPC_SERVER_CALL_H #include +#include #include "ray/common/grpc_util.h" #include "ray/common/status.h"