diff --git a/.gitignore b/.gitignore index 4001d29be..642aca1b0 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ /bazel-* /python/ray/core /python/ray/pyarrow_files/ +/python/ray/pickle5_files/ /python/build /python/dist /thirdparty/pkg/ diff --git a/BUILD.bazel b/BUILD.bazel index 2ac2f7f22..7e87dd1ce 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -404,6 +404,16 @@ cc_binary( ], ) +cc_test( + name = "direct_task_transport_test", + srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"], + copts = COPTS, + deps = [ + ":core_worker_lib", + "@com_google_googletest//:gtest_main", + ], +) + cc_test( name = "reference_count_test", srcs = ["src/ray/core_worker/reference_count_test.cc"], diff --git a/ci/travis/format.sh b/ci/travis/format.sh index 8e43dba2b..616aa0d1a 100755 --- a/ci/travis/format.sh +++ b/ci/travis/format.sh @@ -70,29 +70,29 @@ format_changed() { # could cause yapf to receive 0 positional arguments, making it hang # waiting for STDIN. # - # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that + # `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that # exist on both branches. MERGEBASE="$(git merge-base upstream/master HEAD)" - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ + if ! 
git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}" if which flake8 >/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \ + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605 fi fi - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then + if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then if which flake8 >/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605 fi fi if which clang-format >/dev/null; then - if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then - git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \ + if ! 
git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \ clang-format -i fi fi diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 3da8de358..b1df36d56 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -93,6 +93,7 @@ from ray.ray_constants import ( DEFAULT_PUT_OBJECT_DELAY, DEFAULT_PUT_OBJECT_RETRIES, RAW_BUFFER_METADATA, + PICKLE_BUFFER_METADATA, PICKLE5_BUFFER_METADATA, ) @@ -334,6 +335,7 @@ cdef c_vector[c_string] string_vector_from_list(list string_list): cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector): cdef: c_string pickled_str + c_string metadata_str = PICKLE_BUFFER_METADATA shared_ptr[CBuffer] arg_data shared_ptr[CBuffer] arg_metadata @@ -353,6 +355,11 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector): (pickled_str.data()), pickled_str.size(), True)) + arg_metadata = dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + ( + metadata_str.data()), metadata_str.size(), True)) args_vector.push_back( CTaskArg.PassByValue( make_shared[CRayObject](arg_data, arg_metadata))) @@ -436,8 +443,18 @@ cdef deserialize_args( c_args[i].get().GetMetadata()).to_pybytes() == RAW_BUFFER_METADATA): args.append(data) - else: + elif (c_args[i].get().HasMetadata() and Buffer.make( + c_args[i].get().GetMetadata()).to_pybytes() + == PICKLE_BUFFER_METADATA): + # This is a pickled "simple python value" argument. args.append(pickle.loads(data.to_pybytes())) + else: + # This is a Ray object inlined by the direct task submitter. + by_reference_ids.append( + ObjectID(arg_reference_ids[i].Binary())) + by_reference_indices.append(i) + by_reference_objects.push_back(c_args[i]) + args.append(None) # Passed by reference. 
else: by_reference_ids.append( @@ -658,12 +675,14 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str): cdef shared_ptr[CBuffer] empty_metadata if c_str.size() == 0: return empty_metadata - return dynamic_pointer_cast[CBuffer, LocalMemoryBuffer]( - make_shared[LocalMemoryBuffer]((c_str.data()), - c_str.size(), True)) + return dynamic_pointer_cast[ + CBuffer, LocalMemoryBuffer]( + make_shared[LocalMemoryBuffer]( + (c_str.data()), c_str.size(), True)) -cdef write_serialized_object(serialized_object, const shared_ptr[CBuffer]& buf): +cdef write_serialized_object( + serialized_object, const shared_ptr[CBuffer]& buf): # avoid initializing pyarrow before raylet from ray.serialization import Pickle5SerializedObject, RawSerializedObject @@ -851,6 +870,7 @@ cdef class CoreWorker: function_descriptor, args, int num_return_vals, + c_bool is_direct_call, resources): cdef: unordered_map[c_string, double] c_resources @@ -861,7 +881,8 @@ cdef class CoreWorker: with self.profile_event(b"submit_task"): prepare_resources(resources, &c_resources) - task_options = CTaskOptions(num_return_vals, c_resources) + task_options = CTaskOptions( + num_return_vals, is_direct_call, c_resources) ray_function = CRayFunction( LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) prepare_args(args, &args_vector) @@ -925,7 +946,7 @@ cdef class CoreWorker: with self.profile_event(b"submit_task"): if num_method_cpus > 0: c_resources[b"CPU"] = num_method_cpus - task_options = CTaskOptions(num_return_vals, c_resources) + task_options = CTaskOptions(num_return_vals, False, c_resources) ray_function = CRayFunction( LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) prepare_args(args, &args_vector) @@ -1017,7 +1038,8 @@ cdef class CoreWorker: context = worker.get_serialization_context() serialized_object = context.serialize(output) data_sizes.push_back(serialized_object.total_bytes) - metadatas.push_back(string_to_buffer(serialized_object.metadata)) + metadatas.push_back( + 
string_to_buffer(serialized_object.metadata)) serialized_objects.append(serialized_object) check_status(self.core_worker.get().AllocateReturnObjects( @@ -1030,4 +1052,5 @@ cdef class CoreWorker: if serialized_object is NoReturn: returns[0][i].reset() else: - write_serialized_object(serialized_object, returns[0][i].get().GetData()) + write_serialized_object( + serialized_object, returns[0][i].get().GetData()) diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index 9aa2807fb..2ffa0b979 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -168,7 +168,9 @@ class Dashboard(object): raise ValueError( "Dashboard static asset directory not found at '{}'. If " "installing from source, please follow the additional steps " - "required to build the dashboard.".format(static_dir)) + "required to build the dashboard: " + "cd python/ray/dashboard/client && npm ci && " + "npm run build".format(static_dir)) self.app.router.add_static("/static", static_dir) self.app.router.add_get("/api/ray_config", ray_config) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index b7187f87e..e16f944c1 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -201,7 +201,7 @@ cdef extern from "ray/core_worker/common.h" nogil: cdef cppclass CTaskOptions "ray::TaskOptions": CTaskOptions() - CTaskOptions(int num_returns, + CTaskOptions(int num_returns, c_bool is_direct_call, unordered_map[c_string, double] &resources) cdef cppclass CActorCreationOptions "ray::ActorCreationOptions": diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 04f2207e2..c317e622b 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -148,7 +148,7 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: c_bool is_put() - c_bool IsDirectActorType() + c_bool IsDirectCallType() int64_t ObjectIndex() const diff --git 
a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 65ce4755f..6ed41f57a 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -176,8 +176,8 @@ cdef class ObjectID(BaseID): def hex(self): return decode(self.data.Hex()) - def is_direct_actor_type(self): - return self.data.IsDirectActorType() + def is_direct_call_type(self): + return self.data.IsDirectCallType() def is_nil(self): return self.data.IsNil() diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 8bad16755..4d54d3238 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -179,6 +179,10 @@ LOG_MONITOR_MAX_OPEN_FILES = 200 # A constant used as object metadata to indicate the object is raw binary. RAW_BUFFER_METADATA = b"RAW" +# A constant used as object metadata to indicate the object is pickled. This +# format is only ever used for Python inline task argument values. +PICKLE_BUFFER_METADATA = b"PICKLE" +# A constant used as object metadata to indicate the object is pickle5 format. 
PICKLE5_BUFFER_METADATA = b"PICKLE5" AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request" diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 69083f354..67d23901b 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -134,6 +134,7 @@ class RemoteFunction(object): args=None, kwargs=None, num_return_vals=None, + is_direct_call=None, num_cpus=None, num_gpus=None, memory=None, @@ -155,6 +156,8 @@ class RemoteFunction(object): if num_return_vals is None: num_return_vals = self._num_return_vals + if is_direct_call is None: + is_direct_call = False resources = ray.utils.resources_from_resource_arguments( self._num_cpus, self._num_gpus, self._memory, @@ -162,8 +165,11 @@ class RemoteFunction(object): memory, object_store_memory, resources) def invocation(args, kwargs): - list_args = ray.signature.flatten_args(self._function_signature, - args, kwargs) + if not args and not kwargs and not self._function_signature: + list_args = [] + else: + list_args = ray.signature.flatten_args( + self._function_signature, args, kwargs) if worker.mode == ray.worker.LOCAL_MODE: object_ids = worker.local_mode_manager.execute( @@ -172,7 +178,7 @@ class RemoteFunction(object): else: object_ids = worker.core_worker.submit_task( self._function_descriptor_list, list_args, num_return_vals, - resources) + is_direct_call, resources) if len(object_ids) == 1: return object_ids[0] diff --git a/python/ray/serialization.py b/python/ray/serialization.py index fa923c1d0..f6a6af8ac 100644 --- a/python/ray/serialization.py +++ b/python/ray/serialization.py @@ -157,8 +157,7 @@ class SerializationContext(object): serialization_context) def id_serializer(obj): - if isinstance(obj, - ray.ObjectID) and obj.is_direct_actor_type(): + if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type(): raise NotImplementedError( "Objects produced by direct actor calls cannot be " "passed to other tasks as arguments.") @@ -191,8 +190,7 @@ class 
SerializationContext(object): custom_deserializer=actor_handle_deserializer) def id_serializer(obj): - if isinstance(obj, - ray.ObjectID) and obj.is_direct_actor_type(): + if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type(): raise NotImplementedError( "Objects produced by direct actor calls cannot be " "passed to other tasks as arguments.") diff --git a/python/ray/signature.py b/python/ray/signature.py index b06445713..2e274d457 100644 --- a/python/ray/signature.py +++ b/python/ray/signature.py @@ -7,7 +7,6 @@ import funcsigs from funcsigs import Parameter import logging -import ray from ray.utils import is_cython # Logger for this module. It should be configured at the entry point @@ -136,12 +135,6 @@ def flatten_args(signature_parameters, args, kwargs): [None, 1, None, 2, None, 3, "a", 4] """ - for obj in args: - if isinstance(obj, ray.ObjectID) and obj.is_direct_actor_type(): - raise NotImplementedError( - "Objects produced by direct actor calls cannot be " - "passed to other tasks as arguments.") - restored = _restore_parameters(signature_parameters) reconstructed_signature = funcsigs.Signature(parameters=restored) try: diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index f92da9eb2..2c234ac65 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1190,6 +1190,31 @@ def test_get_dict(ray_start_regular): assert result == expected +def test_direct_call_simple(ray_start_regular): + @ray.remote + def f(x): + return x + 1 + + f_direct = f.options(is_direct_call=True) + print("a") + assert ray.get(f_direct.remote(2)) == 3 + print("b") + assert ray.get([f_direct.remote(i) for i in range(100)]) == list( + range(1, 101)) + + +def test_direct_call_chain(ray_start_regular): + @ray.remote + def g(x): + return x + 1 + + g_direct = g.options(is_direct_call=True) + x = 0 + for _ in range(100): + x = g_direct.remote(x) + assert ray.get(x) == 100 + + def test_direct_actor_enabled(ray_start_regular): 
@ray.remote class Actor(object): @@ -1240,10 +1265,6 @@ def test_direct_actor_errors(ray_start_regular): a = Actor._remote(is_direct_call=True) - # cannot pass returns to other methods directly - with pytest.raises(Exception): - ray.get(f.remote(a.f.remote(2))) - # cannot pass returns to other methods even in a list with pytest.raises(Exception): ray.get(f.remote([a.f.remote(2)])) diff --git a/src/ray/common/grpc_util.h b/src/ray/common/grpc_util.h index 1931b3937..bc515046c 100644 --- a/src/ray/common/grpc_util.h +++ b/src/ray/common/grpc_util.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "status.h" namespace ray { @@ -39,7 +40,7 @@ class MessageWrapper { const Message &GetMessage() const { return *message_; } /// Get reference of the protobuf message. - Message &GetMutableMessage() const { return *message_; } + Message &GetMutableMessage() { return *message_; } /// Serialize the message to a string. const std::string Serialize() const { return message_->SerializeAsString(); } @@ -64,7 +65,11 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { if (grpc_status.ok()) { return Status::OK(); } else { - return Status::IOError(grpc_status.error_message()); + std::stringstream msg; + msg << grpc_status.error_code(); + msg << ": "; + msg << grpc_status.error_message(); + return Status::IOError(msg.str()); } } diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 5a68b0b05..a847ea4fc 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -18,7 +18,7 @@ namespace ray { -enum class TaskTransportType { RAYLET, DIRECT_ACTOR }; +enum class TaskTransportType { RAYLET, DIRECT }; class TaskID; class WorkerID; @@ -290,8 +290,8 @@ class ObjectID : public BaseID { /// Return if this is a direct actor call object. /// /// \return True if this is a direct actor object return. 
- bool IsDirectActorType() const { - return GetTransportType() == static_cast(TaskTransportType::DIRECT_ACTOR); + bool IsDirectCallType() const { + return GetTransportType() == static_cast(TaskTransportType::DIRECT); } /// Return this object id with a changed transport type. diff --git a/src/ray/common/task/task.h b/src/ray/common/task/task.h index 4c37ebdd4..596cd49b5 100644 --- a/src/ray/common/task/task.h +++ b/src/ray/common/task/task.h @@ -9,6 +9,9 @@ namespace ray { +typedef std::function, const std::string &, int)> + DispatchTaskCallback; + /// \class Task /// /// A Task represents a Ray task and a specification of its execution (e.g., @@ -38,6 +41,13 @@ class Task { ComputeDependencies(); } + /// Override dispatch behaviour. + void OnDispatchInstead( + std::function, const std::string &, int)> + callback) { + on_dispatch_ = callback; + } + /// Get the mutable specification for the task. This specification may be /// updated at runtime. /// @@ -62,6 +72,9 @@ class Task { /// \param task Task structure with updated dynamic information. void CopyTaskExecutionSpec(const Task &task); + /// Returns the override dispatch task callback, or nullptr. + DispatchTaskCallback &OnDispatch() const { return on_dispatch_; } + std::string DebugString() const; private: @@ -78,6 +91,10 @@ class Task { /// the TaskSpecification and execution dependencies from the /// TaskExecutionSpecification. std::vector dependencies_; + + /// For direct task calls, overrides the dispatch behaviour to send an RPC + /// back to the submitting worker. 
+ mutable DispatchTaskCallback on_dispatch_ = nullptr; }; } // namespace ray diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 061900aaa..2738aa5b3 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -187,8 +187,11 @@ ObjectID TaskSpecification::ActorDummyObject() const { } bool TaskSpecification::IsDirectCall() const { - RAY_CHECK(IsActorCreationTask()); - return message_->actor_creation_task_spec().is_direct_call(); + if (IsActorCreationTask()) { + return message_->actor_creation_task_spec().is_direct_call(); + } else { + return message_->is_direct_call(); + } } int TaskSpecification::MaxActorConcurrency() const { diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index b4720f7ba..5c2b993a9 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -22,6 +22,8 @@ typedef std::pair SchedulingClassDescriptor; typedef int SchedulingClass; /// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details. +/// TODO(ekl) we should consider passing around std::unique_ptrs +/// instead `const TaskSpecification`, since this class is actually mutable. class TaskSpecification : public MessageWrapper { public: /// Construct an empty task specification. This should not be used directly. 
diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index bbe855f6a..bd6170663 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -27,7 +27,7 @@ class TaskSpecBuilder { const TaskID &task_id, const Language &language, const std::vector &function_descriptor, const JobID &job_id, const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id, - uint64_t num_returns, + uint64_t num_returns, bool is_direct_call, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources) { message_->set_type(TaskType::NORMAL_TASK); @@ -41,6 +41,7 @@ class TaskSpecBuilder { message_->set_parent_counter(parent_counter); message_->set_caller_id(caller_id.Binary()); message_->set_num_returns(num_returns); + message_->set_is_direct_call(is_direct_call); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index a25f377eb..ebcc1f831 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -84,11 +84,14 @@ class TaskArg { /// Options for all tasks (actor and non-actor) except for actor creation. struct TaskOptions { TaskOptions() {} - TaskOptions(int num_returns, std::unordered_map &resources) - : num_returns(num_returns), resources(resources) {} + TaskOptions(int num_returns, bool is_direct_call, + std::unordered_map &resources) + : num_returns(num_returns), is_direct_call(is_direct_call), resources(resources) {} /// Number of returns of this task. int num_returns = 1; + /// Whether to use the direct task transport. + bool is_direct_call = false; /// Resources required by this task. 
std::unordered_map resources; }; diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index e2dc5fd98..634fd87a4 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -89,12 +89,13 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) { if (task_spec.IsNormalTask()) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); + current_task_is_direct_call_ = task_spec.IsDirectCall(); } else if (task_spec.IsActorCreationTask()) { RAY_CHECK(current_job_id_.IsNil()); SetCurrentJobId(task_spec.JobId()); RAY_CHECK(current_actor_id_.IsNil()); current_actor_id_ = task_spec.ActorCreationId(); - current_actor_use_direct_call_ = task_spec.IsDirectCall(); + current_task_is_direct_call_ = task_spec.IsDirectCall(); current_actor_max_concurrency_ = task_spec.MaxActorConcurrency(); } else if (task_spec.IsActorTask()) { RAY_CHECK(current_job_id_ == task_spec.JobId()); @@ -117,8 +118,8 @@ std::shared_ptr WorkerContext::GetCurrentTask() const { const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; } -bool WorkerContext::CurrentActorUseDirectCall() const { - return current_actor_use_direct_call_; +bool WorkerContext::CurrentTaskIsDirectCall() const { + return current_task_is_direct_call_; } int WorkerContext::CurrentActorMaxConcurrency() const { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 531706cca..9a30e0123 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -34,7 +34,7 @@ class WorkerContext { const ActorID &GetCurrentActorID() const; - bool CurrentActorUseDirectCall() const; + bool CurrentTaskIsDirectCall() const; int CurrentActorMaxConcurrency() const; @@ -47,8 +47,8 @@ class WorkerContext { const WorkerID worker_id_; JobID current_job_id_; ActorID current_actor_id_; - bool current_actor_use_direct_call_ = false; - int current_actor_max_concurrency_; + bool current_task_is_direct_call_ = false; + 
int current_actor_max_concurrency_ = 1; private: static WorkerThreadContext &GetThreadContext(bool for_main_thread = false); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index af4ed02e9..b2196c4f3 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -19,11 +19,16 @@ void BuildCommonTaskSpec( // Build common task spec. builder.SetCommonTaskSpec(task_id, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, current_task_id, - task_index, caller_id, num_returns, required_resources, - required_placement_resources); + task_index, caller_id, num_returns, + transport_type == ray::TaskTransportType::DIRECT, + required_resources, required_placement_resources); // Set task arguments. for (const auto &arg : args) { if (arg.IsPassedByReference()) { + if (transport_type == ray::TaskTransportType::RAYLET) { + RAY_CHECK(!arg.GetReference().IsDirectCallType()) + << "NotImplemented: passing direct call objects to other tasks"; + } builder.AddByRefArg(arg.GetReference()); } else { builder.AddByValueArg(arg.GetValue()); @@ -44,7 +49,7 @@ void GroupObjectIdsByStoreProvider(const std::vector &object_ids, absl::flat_hash_set *plasma_object_ids, absl::flat_hash_set *memory_object_ids) { for (const auto &object_id : object_ids) { - if (object_id.IsDirectActorType()) { + if (object_id.IsDirectCallType()) { memory_object_ids->insert(object_id); } else { plasma_object_ids->insert(object_id); @@ -70,11 +75,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, check_signals_(check_signals), worker_context_(worker_type, job_id), io_work_(io_service_), + client_call_manager_(new rpc::ClientCallManager(io_service_)), heartbeat_timer_(io_service_), - worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), + core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), gcs_client_(gcs_options), - client_call_manager_(io_service_), 
memory_store_(std::make_shared()), + memory_store_provider_(memory_store_), task_execution_service_work_(task_execution_service_), task_execution_callback_(task_execution_callback), grpc_service_(io_service_, *this) { @@ -95,24 +101,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, profiler_ = std::make_shared(worker_context_, node_ip_address, io_service_, gcs_client_); - // Initialize task execution. + // Initialize task receivers. if (worker_type_ == WorkerType::WORKER) { RAY_CHECK(task_execution_callback_ != nullptr); - - // Initialize task receivers. auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3); raylet_task_receiver_ = std::unique_ptr( new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler)); - direct_actor_task_receiver_ = std::unique_ptr( - new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_, - worker_server_, execute_task, - exit_handler)); - worker_server_.RegisterService(grpc_service_); + direct_task_receiver_ = + std::unique_ptr(new CoreWorkerDirectTaskReceiver( + worker_context_, task_execution_service_, execute_task, exit_handler)); } // Start RPC server after all the task receivers are properly initialized. - worker_server_.Run(); + core_worker_server_.RegisterService(grpc_service_); + core_worker_server_.Run(); // Initialize raylet client. // TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot @@ -120,15 +123,15 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. 
auto grpc_client = rpc::NodeManagerWorkerClient::make( - node_ip_address, node_manager_port, client_call_manager_); + node_ip_address, node_manager_port, *client_call_manager_); raylet_client_ = std::unique_ptr(new RayletClient( std::move(grpc_client), raylet_socket, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), - language_, worker_server_.GetPort())); + language_, core_worker_server_.GetPort())); // Unfortunately the raylet client has to be constructed after the receivers. - if (direct_actor_task_receiver_ != nullptr) { - direct_actor_task_receiver_->Init(*raylet_client_); + if (direct_task_receiver_ != nullptr) { + direct_task_receiver_->Init(*raylet_client_); } // Set timer to periodically send heartbeats containing active object IDs to the raylet. @@ -149,7 +152,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, plasma_store_provider_.reset( new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_)); - memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_)); // Create an entry for the driver task in the task table. This task is // added immediately with status RUNNING. 
This allows us to push errors @@ -162,10 +164,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, std::vector empty_descriptor; std::unordered_map empty_resources; const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID()); - builder.SetCommonTaskSpec(task_id, language_, empty_descriptor, - worker_context_.GetCurrentJobID(), - TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), - 0, GetCallerId(), 0, empty_resources, empty_resources); + builder.SetCommonTaskSpec( + task_id, language_, empty_descriptor, worker_context_.GetCurrentJobID(), + TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), 0, GetCallerId(), 0, + false, empty_resources, empty_resources); std::shared_ptr data = std::make_shared(); data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); @@ -173,11 +175,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, SetCurrentTaskId(task_id); } - // TODO(edoakes): why don't we just share the memory store provider? 
direct_actor_submitter_ = std::unique_ptr( - new CoreWorkerDirectActorTaskSubmitter( - io_service_, std::unique_ptr( - new CoreWorkerMemoryStoreProvider(memory_store_)))); + new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_, + memory_store_provider_)); + + direct_task_submitter_ = + std::unique_ptr(new CoreWorkerDirectTaskSubmitter( + *raylet_client_, + [this](WorkerAddress addr) { + return std::shared_ptr(new rpc::CoreWorkerClient( + addr.first, addr.second, *client_call_manager_)); + }, + memory_store_provider_)); } CoreWorker::~CoreWorker() { @@ -308,9 +317,9 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m local_timeout_ms = std::max(static_cast(0), timeout_ms - (current_time_ms() - start_time)); } - RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms, - worker_context_.GetCurrentTaskID(), - &result_map, &got_exception)); + RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms, + worker_context_.GetCurrentTaskID(), + &result_map, &got_exception)); } // If any of the objects have been promoted to plasma, then we retry their @@ -334,9 +343,9 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m for (const auto &id : promoted_plasma_ids) { auto it = result_map.find(id); if (it == result_map.end()) { - result_map.erase(id.WithTransportType(TaskTransportType::DIRECT_ACTOR)); + result_map.erase(id.WithTransportType(TaskTransportType::DIRECT)); } else { - result_map[id.WithTransportType(TaskTransportType::DIRECT_ACTOR)] = it->second; + result_map[id.WithTransportType(TaskTransportType::DIRECT)] = it->second; } result_map.erase(id); } @@ -360,10 +369,10 @@ Status CoreWorker::Get(const std::vector &ids, const int64_t timeout_m Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) { bool found = false; - if (object_id.IsDirectActorType()) { + if (object_id.IsDirectCallType()) { // Note that the memory store returns false if the object value 
is // ErrorType::OBJECT_IN_PLASMA. - RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found)); + RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found)); } if (!found) { // We check plasma as a fallback in all cases, since a direct call object @@ -413,7 +422,7 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, if (memory_object_ids.size() > 0) { // TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should // consider waiting on them in plasma as well to ensure they are local. - RAY_RETURN_NOT_OK(memory_store_provider_->Wait( + RAY_RETURN_NOT_OK(memory_store_provider_.Wait( memory_object_ids, std::max(0, static_cast(ready.size()) - num_objects), /*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready)); } @@ -431,8 +440,8 @@ Status CoreWorker::Wait(const std::vector &ids, int num_objects, } if (memory_object_ids.size() > 0) { RAY_RETURN_NOT_OK( - memory_store_provider_->Wait(memory_object_ids, num_objects, timeout_ms, - worker_context_.GetCurrentTaskID(), &ready)); + memory_store_provider_.Wait(memory_object_ids, num_objects, timeout_ms, + worker_context_.GetCurrentTaskID(), &ready)); } } @@ -453,7 +462,7 @@ Status CoreWorker::Delete(const std::vector &object_ids, bool local_on RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only, delete_creating_tasks)); - RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids)); + RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids)); return Status::OK(); } @@ -494,8 +503,7 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) { if (task_deps->size() > 0) { for (size_t i = 0; i < num_returns; i++) { - reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET), - task_deps); + reference_counter_.SetDependencies(task_spec.ReturnIdForPlasma(i), task_deps); } } @@ -511,11 +519,19 @@ Status CoreWorker::SubmitTask(const RayFunction &function, const auto task_id = 
TaskID::ForNormalTask(worker_context_.GetCurrentJobID(), worker_context_.GetCurrentTaskID(), next_task_index); - BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id, - worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), - function, args, task_options.num_returns, task_options.resources, - {}, TaskTransportType::RAYLET, return_ids); - return SubmitTaskToRaylet(builder.Build()); + + // TODO(ekl) offload task building onto a thread pool for performance + BuildCommonTaskSpec( + builder, worker_context_.GetCurrentJobID(), task_id, + worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), function, args, + task_options.num_returns, task_options.resources, {}, + task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET, + return_ids); + if (task_options.is_direct_call) { + return direct_task_submitter_->SubmitTask(builder.Build()); + } else { + return raylet_client_->SubmitTask(builder.Build()); + } } Status CoreWorker::CreateActor(const RayFunction &function, @@ -562,7 +578,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f const bool is_direct_call = actor_handle->IsDirectCallActor(); const TaskTransportType transport_type = - is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET; + is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET; // Build common task spec. 
TaskSpecBuilder builder; @@ -675,7 +691,7 @@ Status CoreWorker::AllocateReturnObjects( bool object_already_exists = false; std::shared_ptr data_buffer; if (data_sizes[i] > 0) { - if (worker_context_.CurrentActorUseDirectCall() && + if (worker_context_.CurrentTaskIsDirectCall() && static_cast(data_sizes[i]) < RayConfig::instance().max_direct_call_object_size()) { data_buffer = std::make_shared(data_sizes[i]); @@ -710,8 +726,8 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, std::vector arg_reference_ids; RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids)); - const auto transport_type = worker_context_.CurrentActorUseDirectCall() - ? TaskTransportType::DIRECT_ACTOR + const auto transport_type = worker_context_.CurrentTaskIsDirectCall() + ? TaskTransportType::DIRECT : TaskTransportType::RAYLET; std::vector return_ids; for (size_t i = 0; i < task_spec.NumReturns(); i++) { @@ -745,7 +761,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to seal object " << return_ids[i] << " in store: " << status.message(); } - } else if (!worker_context_.CurrentActorUseDirectCall()) { + } else if (!worker_context_.CurrentTaskIsDirectCall()) { if (!Put(*return_objects->at(i), return_ids[i]).ok()) { RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to put object " << return_ids[i] << " in store: " << status.message(); @@ -816,7 +832,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { - if (worker_context_.CurrentActorUseDirectCall()) { + if (worker_context_.CurrentTaskIsDirectCall()) { send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr, nullptr); return; @@ -827,12 +843,11 @@ void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request, } } 
-void CoreWorker::HandleDirectActorAssignTask( - const rpc::DirectActorAssignTaskRequest &request, - rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { +void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request, + rpc::PushTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { task_execution_service_.post([=] { - direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply, - send_reply_callback); + direct_task_receiver_->HandlePushTask(request, reply, send_reply_callback); }); } @@ -841,9 +856,19 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete( rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback) { task_execution_service_.post([=] { - direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete( - request, reply, send_reply_callback); + direct_task_receiver_->HandleDirectActorCallArgWaitComplete(request, reply, + send_reply_callback); }); } +void CoreWorker::HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request, + rpc::WorkerLeaseGrantedReply *reply, + rpc::SendReplyCallback send_reply_callback) { + // Run this directly since the main thread may be tied up processing a task and + // we need to still continue processing these scheduling operations in the backend. 
+ direct_task_submitter_->HandleWorkerLeaseGranted( + std::make_pair(request.address(), request.port())); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + } // namespace ray diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 2a8d6db32..24ebe42c4 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -14,22 +14,24 @@ #include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/core_worker/store_provider/plasma_store_provider.h" #include "ray/core_worker/transport/direct_actor_transport.h" +#include "ray/core_worker/transport/direct_task_transport.h" #include "ray/core_worker/transport/raylet_transport.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/raylet/raylet_client.h" #include "ray/rpc/node_manager/node_manager_client.h" -#include "ray/rpc/worker/worker_client.h" -#include "ray/rpc/worker/worker_server.h" +#include "ray/rpc/worker/core_worker_client.h" +#include "ray/rpc/worker/core_worker_server.h" /// The set of gRPC handlers and their associated level of concurrency. 
If you want to /// add a new call to the worker gRPC server, do the following: -/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall" +/// 1) Add the rpc to the CoreWorkerService in core_worker.proto, e.g., "ExampleCall" /// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)" /// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall" -#define RAY_CORE_WORKER_RPC_HANDLERS \ - RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \ - RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \ - RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) +#define RAY_CORE_WORKER_RPC_HANDLERS \ + RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \ + RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \ + RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) \ + RAY_CORE_WORKER_RPC_HANDLER(WorkerLeaseGranted, 5) namespace ray { @@ -319,29 +321,32 @@ class CoreWorker { const std::vector> &metadatas, std::vector> *return_objects); - /* Handlers for the worker's gRPC server. These are executed on the io_service_ and post - * work to the appropriate event loop. + /** + * The following methods are handlers for the core worker's gRPC server, which follow + * a macro-generated call convention. These are executed on the io_service_ and + * post work to the appropriate event loop. */ - /// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task - /// on this worker from the raylet. + /// Implements gRPC server handler. void HandleAssignTask(const rpc::AssignTaskRequest &request, rpc::AssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback); - /// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task - /// on this worker from another worker. 
- void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request, - rpc::DirectActorAssignTaskReply *reply, - rpc::SendReplyCallback send_reply_callback); + /// Implements gRPC server handler. + void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); - /// Handle a "DirectActorAssignTask" event corresponding to the raylet notifiying this - /// worker that an argument is ready. + /// Implements gRPC server handler. void HandleDirectActorCallArgWaitComplete( const rpc::DirectActorCallArgWaitCompleteRequest &request, rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback); + /// Implements gRPC server handler. + void HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request, + rpc::WorkerLeaseGrantedReply *reply, + rpc::SendReplyCallback send_reply_callback); + private: /// Run the io_service_ event loop. This should be called in a background thread. void RunIOService(); @@ -446,19 +451,19 @@ class CoreWorker { /// Keeps the io_service_ alive. boost::asio::io_service::work io_work_; + /// Shared client call manager. + std::unique_ptr client_call_manager_; + /// Timer used to periodically send heartbeat containing active object IDs to the /// raylet. boost::asio::steady_timer heartbeat_timer_; /// RPC server used to receive tasks to execute. - rpc::GrpcServer worker_server_; + rpc::GrpcServer core_worker_server_; // Client to the GCS shared by core worker interfaces. gcs::RedisGcsClient gcs_client_; - /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. - rpc::ClientCallManager client_call_manager_; - // Client to the raylet shared by core worker interfaces. std::unique_ptr raylet_client_; @@ -479,7 +484,7 @@ class CoreWorker { std::unique_ptr plasma_store_provider_; /// In-memory store interface. 
- std::unique_ptr memory_store_provider_; + CoreWorkerMemoryStoreProvider memory_store_provider_; /// /// Fields related to task submission. @@ -488,6 +493,9 @@ class CoreWorker { // Interface to submit tasks directly to other actors. std::unique_ptr direct_actor_submitter_; + // Interface to submit non-actor tasks directly to leased workers. + std::unique_ptr direct_task_submitter_; + /// Map from actor ID to a handle to that actor. absl::flat_hash_map> actor_handles_; @@ -519,10 +527,10 @@ class CoreWorker { std::unique_ptr raylet_task_receiver_; /// Common rpc service for all worker modules. - rpc::WorkerGrpcService grpc_service_; + rpc::CoreWorkerGrpcService grpc_service_; // Interface that receives tasks from direct actor calls. - std::unique_ptr direct_actor_task_receiver_; + std::unique_ptr direct_task_receiver_; friend class CoreWorkerTest; }; diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.cc b/src/ray/core_worker/store_provider/memory_store/memory_store.cc index 0f0627d13..0084778da 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.cc +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.cc @@ -106,32 +106,66 @@ std::shared_ptr GetRequest::Get(const ObjectID &object_id) const { CoreWorkerMemoryStore::CoreWorkerMemoryStore() {} -Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) { - std::unique_lock lock(lock_); - auto iter = objects_.find(object_id); - if (iter != objects_.end()) { - return Status::ObjectExists("object already exists in the memory store"); +void CoreWorkerMemoryStore::GetAsync( + const ObjectID &object_id, std::function)> callback) { + std::shared_ptr ptr; + { + absl::MutexLock lock(&mu_); + auto iter = objects_.find(object_id); + if (iter != objects_.end()) { + ptr = iter->second; + } else { + object_async_get_requests_[object_id].push_back(callback); + } } + // It's important for performance to run the callback outside the lock. 
+ if (ptr != nullptr) { + callback(ptr); + } +} +Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) { + std::vector)>> async_callbacks; auto object_entry = std::make_shared(object.GetData(), object.GetMetadata(), true); - bool should_add_entry = true; - auto object_request_iter = object_get_requests_.find(object_id); - if (object_request_iter != object_get_requests_.end()) { - auto &get_requests = object_request_iter->second; - for (auto &get_request : get_requests) { - get_request->Set(object_id, object_entry); - if (get_request->ShouldRemoveObjects()) { - should_add_entry = false; + { + absl::MutexLock lock(&mu_); + auto iter = objects_.find(object_id); + if (iter != objects_.end()) { + return Status::ObjectExists("object already exists in the memory store"); + } + + auto async_callback_it = object_async_get_requests_.find(object_id); + if (async_callback_it != object_async_get_requests_.end()) { + auto &callbacks = async_callback_it->second; + async_callbacks = std::move(callbacks); + object_async_get_requests_.erase(async_callback_it); + } + + bool should_add_entry = true; + auto object_request_iter = object_get_requests_.find(object_id); + if (object_request_iter != object_get_requests_.end()) { + auto &get_requests = object_request_iter->second; + for (auto &get_request : get_requests) { + get_request->Set(object_id, object_entry); + if (get_request->ShouldRemoveObjects()) { + should_add_entry = false; + } } } + + if (should_add_entry) { + // If there is no existing get request, then add the `RayObject` to map. + objects_.emplace(object_id, object_entry); + } } - if (should_add_entry) { - // If there is no existing get request, then add the `RayObject` to map. - objects_.emplace(object_id, object_entry); + // It's important for performance to run the callbacks outside the lock. 
+ for (const auto &cb : async_callbacks) { + cb(object_entry); } + return Status::OK(); } @@ -147,7 +181,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, absl::flat_hash_set remaining_ids; absl::flat_hash_set ids_to_remove; - std::unique_lock lock(lock_); + absl::MutexLock lock(&mu_); // Check for existing objects and see if this get request can be fullfilled. for (size_t i = 0; i < object_ids.size(); i++) { const auto &object_id = object_ids[i]; @@ -192,7 +226,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, get_request->Wait(timeout_ms); { - std::unique_lock lock(lock_); + absl::MutexLock lock(&mu_); // Populate results. for (size_t i = 0; i < object_ids.size(); i++) { const auto &object_id = object_ids[i]; @@ -223,14 +257,14 @@ Status CoreWorkerMemoryStore::Get(const std::vector &object_ids, } void CoreWorkerMemoryStore::Delete(const std::vector &object_ids) { - std::unique_lock lock(lock_); + absl::MutexLock lock(&mu_); for (const auto &object_id : object_ids) { objects_.erase(object_id); } } bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id) { - std::unique_lock lock(lock_); + absl::MutexLock lock(&mu_); auto it = objects_.find(object_id); // If obj is in plasma, we defer to the plasma store for the Contains() call. 
return it != objects_.end() && !it->second->IsInPlasmaError(); diff --git a/src/ray/core_worker/store_provider/memory_store/memory_store.h b/src/ray/core_worker/store_provider/memory_store/memory_store.h index bd9596c43..415da008e 100644 --- a/src/ray/core_worker/store_provider/memory_store/memory_store.h +++ b/src/ray/core_worker/store_provider/memory_store/memory_store.h @@ -3,6 +3,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/synchronization/mutex.h" #include "ray/common/id.h" #include "ray/common/status.h" #include "ray/core_worker/common.h" @@ -39,6 +40,15 @@ class CoreWorkerMemoryStore { Status Get(const std::vector &object_ids, int num_objects, int64_t timeout_ms, bool remove_after_get, std::vector> *results); + /// Asynchronously get an object from the object store. The object will not be removed + /// from storage after GetAsync (TODO(ekl): integrate this with object GC). + /// + /// \param[in] object_id The object id to get. + /// \param[in] callback The callback to run with the reference to the retrieved + /// object value once available. + void GetAsync(const ObjectID &object_id, + std::function)> callback); + /// Delete a list of objects from the object store. /// /// \param[in] object_ids IDs of the objects to delete. @@ -53,14 +63,19 @@ class CoreWorkerMemoryStore { private: /// Map from object ID to `RayObject`. - absl::flat_hash_map> objects_; + absl::flat_hash_map> objects_ GUARDED_BY(mu_); /// Map from object ID to its get requests. absl::flat_hash_map>> - object_get_requests_; + object_get_requests_ GUARDED_BY(mu_); + + /// Map from object ID to its async get requests. + absl::flat_hash_map)>>> + object_async_get_requests_ GUARDED_BY(mu_); /// Protect the two maps above. 
- std::mutex lock_; + absl::Mutex mu_; }; } // namespace ray diff --git a/src/ray/core_worker/store_provider/memory_store_provider.h b/src/ray/core_worker/store_provider/memory_store_provider.h index 2ae0d66e3..76ac204dd 100644 --- a/src/ray/core_worker/store_provider/memory_store_provider.h +++ b/src/ray/core_worker/store_provider/memory_store_provider.h @@ -19,6 +19,11 @@ class CoreWorkerMemoryStoreProvider { public: CoreWorkerMemoryStoreProvider(std::shared_ptr store); + void GetAsync(const ObjectID &object_id, + std::function)> callback) { + store_->GetAsync(object_id, callback); + } + Status Put(const RayObject &object, const ObjectID &object_id); Status Get(const absl::flat_hash_set &object_ids, int64_t timeout_ms, diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index c9c4c3bd5..22350999a 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -44,16 +44,6 @@ static void flushall_redis(void) { redisFree(context); } -std::shared_ptr GenerateRandomBuffer() { - auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); - std::mt19937 gen(seed); - std::uniform_int_distribution<> dis(1, 10); - std::uniform_int_distribution<> value_dis(1, 255); - - std::vector arg1(dis(gen), value_dis(gen)); - return std::make_shared(arg1.data(), arg1.size(), true); -} - ActorID CreateActorHelper(CoreWorker &worker, std::unordered_map &resources, bool is_direct_call, uint64_t max_reconstructions) { @@ -279,16 +269,15 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso args.emplace_back( TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; RayFunction func(ray::Language::PYTHON, {}); RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); ASSERT_EQ(return_ids.size(), 1); 
ASSERT_TRUE(return_ids[0].IsReturnObject()); - ASSERT_EQ( - static_cast(return_ids[0].GetTransportType()), - is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET); + ASSERT_EQ(static_cast(return_ids[0].GetTransportType()), + is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET); std::vector> results; RAY_CHECK_OK(driver.Get(return_ids, -1, &results)); @@ -320,7 +309,7 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso args.emplace_back( TaskArg::PassByValue(std::make_shared(buffer2, nullptr))); - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; RayFunction func(ray::Language::PYTHON, {}); auto status = driver.SubmitActorTask(actor_id, func, args, options, &return_ids); @@ -380,7 +369,7 @@ void CoreWorkerTest::TestActorReconstruction( args.emplace_back( TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; RayFunction func(ray::Language::PYTHON, {}); @@ -425,7 +414,7 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map &r args.emplace_back( TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; RayFunction func(ray::Language::PYTHON, {}); @@ -486,7 +475,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) { ASSERT_EQ(*data, *buffer); } -// Performance batchmark for `DirectActorAssignTaskRequest` creation. +// Performance batchmark for `PushTaskRequest` creation. TEST_F(ZeroNodeTest, TestTaskSpecPerf) { // Create a dummy actor handle, and then create a number of `TaskSpec` // to benchmark performance. 
@@ -505,20 +494,21 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { function.GetFunctionDescriptor()); // Manually create `num_tasks` task specs, and for each of them create a - // `DirectActorAssignTaskRequest`, this is to batch performance of TaskSpec + // `PushTaskRequest`, this is to batch performance of TaskSpec // creation/copy/destruction. int64_t start_ms = current_time_ms(); const auto num_tasks = 10000 * 10; - RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests"; + RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests"; for (int i = 0; i < num_tasks; i++) { - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; auto num_returns = options.num_returns; TaskSpecBuilder builder; builder.SetCommonTaskSpec(RandomTaskId(), function.GetLanguage(), function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0, - RandomTaskId(), num_returns, resources, resources); + RandomTaskId(), num_returns, /*is_direct*/ false, resources, + resources); // Set task arguments. 
for (const auto &arg : args) { if (arg.IsPassedByReference()) { @@ -531,14 +521,13 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { actor_handle.SetActorTaskSpec(builder, TaskTransportType::RAYLET, ObjectID::FromRandom()); - const auto &task_spec = builder.Build(); + auto task_spec = builder.Build(); ASSERT_TRUE(task_spec.IsActorTask()); - auto request = std::unique_ptr( - new rpc::DirectActorAssignTaskRequest); + auto request = std::unique_ptr(new rpc::PushTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); } - RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests" + RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests" << ", which takes " << current_time_ms() - start_ms << " ms"; } @@ -566,7 +555,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { sizeof(array)); args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); - TaskOptions options{1, resources}; + TaskOptions options{1, false, resources}; std::vector return_ids; RayFunction func(ray::Language::PYTHON, {}); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc new file mode 100644 index 000000000..bf2025b11 --- /dev/null +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -0,0 +1,272 @@ +#include "gtest/gtest.h" + +#include "ray/common/task/task_spec.h" +#include "ray/core_worker/store_provider/memory_store/memory_store.h" +#include "ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/transport/direct_task_transport.h" +#include "ray/raylet/raylet_client.h" +#include "ray/rpc/worker/core_worker_client.h" +#include "src/ray/util/test_util.h" + +namespace ray { + +class MockWorkerClient : public rpc::CoreWorkerClientInterface { + public: + ray::Status PushNormalTask( + std::unique_ptr request, + const rpc::ClientCallback &callback) override { + callbacks.push_back(callback); + return 
Status::OK(); + } + + std::vector> callbacks; +}; + +class MockRayletClient : public WorkerLeaseInterface { + public: + ray::Status ReturnWorker(int worker_port) { + num_workers_returned += 1; + return Status::OK(); + } + + ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) { + num_workers_requested += 1; + return Status::OK(); + } + + int num_workers_requested = 0; + int num_workers_returned = 0; +}; + +TEST(LocalDependencyResolverTest, TestNoDependencies) { + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + CoreWorkerMemoryStoreProvider store(ptr); + LocalDependencyResolver resolver(store); + TaskSpecification task; + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + ASSERT_TRUE(ok); +} + +TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) { + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + CoreWorkerMemoryStoreProvider store(ptr); + LocalDependencyResolver resolver(store); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET); + TaskSpecification task; + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + // We ignore and don't block on plasma dependencies. + ASSERT_TRUE(ok); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) { + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + CoreWorkerMemoryStoreProvider store(ptr); + LocalDependencyResolver resolver(store); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto data = GenerateRandomObject(); + // Ensure the data is already present in the local store. 
+ ASSERT_TRUE(store.Put(*data, obj1).ok()); + ASSERT_TRUE(store.Put(*data, obj2).ok()); + TaskSpecification task; + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + // Tests that the task proto was rewritten to have inline argument values. + ASSERT_TRUE(ok); + ASSERT_FALSE(task.ArgByRef(0)); + ASSERT_FALSE(task.ArgByRef(1)); + ASSERT_NE(task.ArgData(0), nullptr); + ASSERT_NE(task.ArgData(1), nullptr); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { + auto ptr = std::shared_ptr(new CoreWorkerMemoryStore()); + CoreWorkerMemoryStoreProvider store(ptr); + LocalDependencyResolver resolver(store); + ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT); + auto data = GenerateRandomObject(); + TaskSpecification task; + task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary()); + task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary()); + bool ok = false; + resolver.ResolveDependencies(task, [&ok]() { ok = true; }); + ASSERT_EQ(resolver.NumPendingTasks(), 1); + ASSERT_TRUE(!ok); + ASSERT_TRUE(store.Put(*data, obj1).ok()); + ASSERT_TRUE(store.Put(*data, obj2).ok()); + // Tests that the task proto was rewritten to have inline argument values after + // resolution completes. 
+ ASSERT_TRUE(ok); + ASSERT_FALSE(task.ArgByRef(0)); + ASSERT_FALSE(task.ArgByRef(1)); + ASSERT_NE(task.ArgData(0), nullptr); + ASSERT_NE(task.ArgData(1), nullptr); + ASSERT_EQ(resolver.NumPendingTasks(), 0); +} + +TEST(DirectTaskTranportTest, TestSubmitOneTask) { + MockRayletClient raylet_client; + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + TaskSpecification task; + task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + ASSERT_EQ(raylet_client.num_workers_requested, 1); + ASSERT_EQ(raylet_client.num_workers_returned, 0); + ASSERT_EQ(worker_client->callbacks.size(), 0); + + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234)); + ASSERT_EQ(worker_client->callbacks.size(), 1); + + worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(raylet_client.num_workers_returned, 1); +} + +TEST(DirectTaskTranportTest, TestHandleTaskFailure) { + MockRayletClient raylet_client; + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + TaskSpecification task; + task.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task).ok()); + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234)); + // Simulate a system failure, i.e., worker died unexpectedly. 
+ worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply()); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(raylet_client.num_workers_returned, 1); +} + +TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) { + MockRayletClient raylet_client; + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + TaskSpecification task1; + TaskSpecification task2; + TaskSpecification task3; + task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_EQ(raylet_client.num_workers_requested, 1); + + // Task 1 is pushed; worker 2 is requested. + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(raylet_client.num_workers_requested, 2); + + // Task 2 is pushed; worker 3 is requested. + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); + ASSERT_EQ(worker_client->callbacks.size(), 2); + ASSERT_EQ(raylet_client.num_workers_requested, 3); + + // Task 3 is pushed; no more workers requested. + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1002)); + ASSERT_EQ(worker_client->callbacks.size(), 3); + ASSERT_EQ(raylet_client.num_workers_requested, 3); + + // All workers returned. 
+ for (const auto &cb : worker_client->callbacks) { + cb(Status::OK(), rpc::PushTaskReply()); + } + ASSERT_EQ(raylet_client.num_workers_returned, 3); +} + +TEST(DirectTaskTranportTest, TestReuseWorkerLease) { + MockRayletClient raylet_client; + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + TaskSpecification task1; + TaskSpecification task2; + TaskSpecification task3; + task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_TRUE(submitter.SubmitTask(task3).ok()); + ASSERT_EQ(raylet_client.num_workers_requested, 1); + + // Task 1 is pushed. + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(raylet_client.num_workers_requested, 2); + + // Task 1 finishes, Task 2 is scheduled on the same worker. + worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(worker_client->callbacks.size(), 2); + ASSERT_EQ(raylet_client.num_workers_returned, 0); + + // Task 2 finishes, Task 3 is scheduled on the same worker. + worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(worker_client->callbacks.size(), 3); + ASSERT_EQ(raylet_client.num_workers_returned, 0); + + // Task 3 finishes, the worker is returned. + worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(raylet_client.num_workers_returned, 1); + + // The second lease request is returned immediately. 
+ submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); + ASSERT_EQ(raylet_client.num_workers_returned, 2); +} + +TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) { + MockRayletClient raylet_client; + auto worker_client = std::shared_ptr(new MockWorkerClient()); + auto store = std::shared_ptr(new CoreWorkerMemoryStore()); + auto factory = [&](WorkerAddress addr) { return worker_client; }; + CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store); + TaskSpecification task1; + TaskSpecification task2; + task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary()); + + ASSERT_TRUE(submitter.SubmitTask(task1).ok()); + ASSERT_TRUE(submitter.SubmitTask(task2).ok()); + ASSERT_EQ(raylet_client.num_workers_requested, 1); + + // Task 1 is pushed. + submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000)); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(raylet_client.num_workers_requested, 2); + + // Task 1 finishes with failure; the worker is returned. + worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply()); + ASSERT_EQ(worker_client->callbacks.size(), 1); + ASSERT_EQ(raylet_client.num_workers_returned, 1); + + // Task 2 runs successfully on the second worker. 
+ submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001)); + ASSERT_EQ(worker_client->callbacks.size(), 2); + worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply()); + ASSERT_EQ(raylet_client.num_workers_returned, 2); +} + +} // namespace ray + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/src/ray/core_worker/transport/direct_actor_transport.cc b/src/ray/core_worker/transport/direct_actor_transport.cc index a9e5ba8e4..002db44c7 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.cc +++ b/src/ray/core_worker/transport/direct_actor_transport.cc @@ -6,14 +6,11 @@ using ray::rpc::ActorTableData; namespace ray { CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter( - boost::asio::io_service &io_service, - std::unique_ptr store_provider) - : io_service_(io_service), - client_call_manager_(io_service), - in_memory_store_(std::move(store_provider)) {} + rpc::ClientCallManager &client_call_manager, + CoreWorkerMemoryStoreProvider store_provider) + : client_call_manager_(client_call_manager), in_memory_store_(store_provider) {} -Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( - const TaskSpecification &task_spec) { +Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) { RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId(); RAY_CHECK(task_spec.IsActorTask()); @@ -22,8 +19,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( const auto task_id = task_spec.TaskId(); const auto num_returns = task_spec.NumReturns(); - auto request = std::unique_ptr( - new rpc::DirectActorAssignTaskRequest); + auto request = std::unique_ptr(new rpc::PushTaskRequest); request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); std::unique_lock guard(mutex_); @@ -49,7 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask( // Submit request. 
auto &client = rpc_clients_[actor_id]; - DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns); + PushActorTask(*client, std::move(request), actor_id, task_id, num_returns); } else { // Actor is dead, treat the task as failure. RAY_CHECK(iter->second.state_ == ActorTableData::DEAD); @@ -106,8 +102,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate( void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( const ActorID &actor_id, std::string ip_address, int port) { - std::shared_ptr grpc_client = - std::make_shared(ip_address, port, client_call_manager_); + std::shared_ptr grpc_client = + std::make_shared(ip_address, port, client_call_manager_); RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second); // Submit all pending requests. @@ -117,22 +113,20 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks( auto request = std::move(requests.front()); auto num_returns = request->task_spec().num_returns(); auto task_id = TaskID::FromBinary(request->task_spec().task_id()); - DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns); + PushActorTask(*client, std::move(request), actor_id, task_id, num_returns); requests.pop_front(); } } -void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask( - rpc::WorkerTaskClient &client, - std::unique_ptr request, const ActorID &actor_id, - const TaskID &task_id, int num_returns) { +void CoreWorkerDirectActorTaskSubmitter::PushActorTask( + rpc::CoreWorkerClient &client, std::unique_ptr request, + const ActorID &actor_id, const TaskID &task_id, int num_returns) { RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id; waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns)); - auto status = client.DirectActorAssignTask( - std::move(request), - [this, actor_id, task_id, num_returns]( - Status status, const rpc::DirectActorAssignTaskReply &reply) { + auto status = 
client.PushActorTask( + std::move(request), [this, actor_id, task_id, num_returns]( + Status status, const rpc::PushTaskReply &reply) { { std::unique_lock guard(mutex_); waiting_reply_tasks_[actor_id].erase(task_id); @@ -156,7 +150,7 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask( const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); RAY_CHECK_OK( - in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id)); + in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); } else { std::shared_ptr data_buffer; if (return_object.data().size() > 0) { @@ -172,8 +166,8 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask( reinterpret_cast(return_object.metadata().data())), return_object.metadata().size()); } - RAY_CHECK_OK(in_memory_store_->Put(RayObject(data_buffer, metadata_buffer), - object_id)); + RAY_CHECK_OK( + in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id)); } } }); @@ -189,11 +183,11 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed( for (int i = 0; i < num_returns; i++) { const auto object_id = ObjectID::ForTaskReturn( task_id, /*index=*/i + 1, - /*transport_type=*/static_cast(TaskTransportType::DIRECT_ACTOR)); + /*transport_type=*/static_cast(TaskTransportType::DIRECT)); std::string meta = std::to_string(static_cast(error_type)); auto metadata = const_cast(reinterpret_cast(meta.data())); auto meta_buffer = std::make_shared(metadata, meta.size()); - RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id)); + RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); } } @@ -204,20 +198,19 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE); } -CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver( +CoreWorkerDirectTaskReceiver::CoreWorkerDirectTaskReceiver( 
WorkerContext &worker_context, boost::asio::io_service &main_io_service, - rpc::GrpcServer &server, const TaskHandler &task_handler, - const std::function &exit_handler) + const TaskHandler &task_handler, const std::function &exit_handler) : worker_context_(worker_context), task_handler_(task_handler), exit_handler_(exit_handler), task_main_io_service_(main_io_service) {} -void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) { +void CoreWorkerDirectTaskReceiver::Init(RayletClient &raylet_client) { waiter_.reset(new DependencyWaiterImpl(raylet_client)); } -void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurrency) { +void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) { if (max_concurrency != max_concurrency_) { RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency; RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime."; @@ -226,13 +219,13 @@ void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren } } -void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask( - const rpc::DirectActorAssignTaskRequest &request, - rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { +void CoreWorkerDirectTaskReceiver::HandlePushTask( + const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use"; const TaskSpecification task_spec(request.task_spec()); RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId(); - if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) { + if (task_spec.IsActorTask() && !worker_context_.CurrentTaskIsDirectCall()) { send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."), nullptr, nullptr); return; @@ -257,62 +250,63 @@ void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask( task_main_io_service_, *waiter_, 
pool_))); it = result.first; } - it->second->Add( - request.sequence_number(), request.client_processed_up_to(), - [this, reply, send_reply_callback, task_spec]() { - auto num_returns = task_spec.NumReturns(); - RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask()); - RAY_CHECK(num_returns > 0); - // Decrease to account for the dummy object id. - num_returns--; + it->second->Add(request.sequence_number(), request.client_processed_up_to(), + [this, reply, send_reply_callback, task_spec]() { + auto num_returns = task_spec.NumReturns(); + RAY_CHECK(num_returns > 0); + if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) { + // Decrease to account for the dummy object id. + num_returns--; + } - // TODO(edoakes): resource IDs are currently kept track of in the raylet, - // need to come up with a solution for this. - ResourceMappingType resource_ids; - std::vector> return_objects; - auto status = task_handler_(task_spec, resource_ids, &return_objects); - if (status.IsSystemExit()) { - // In Python, SystemExit can only be raised on the main thread. To work - // around this when we are executing tasks on worker threads, we re-post the - // exit event explicitly on the main thread. - task_main_io_service_.post([this]() { exit_handler_(); }); - return; - } - RAY_CHECK(return_objects.size() == num_returns) - << return_objects.size() << " " << num_returns; + // TODO(edoakes): resource IDs are currently kept track of in the + // raylet, need to come up with a solution for this. + ResourceMappingType resource_ids; + std::vector> return_objects; + auto status = task_handler_(task_spec, resource_ids, &return_objects); + if (status.IsSystemExit()) { + // In Python, SystemExit can only be raised on the main thread. To + // work around this when we are executing tasks on worker threads, + // we re-post the exit event explicitly on the main thread. 
+ task_main_io_service_.post([this]() { exit_handler_(); }); + return; + } + RAY_CHECK(return_objects.size() == num_returns) + << return_objects.size() << " " << num_returns; - for (size_t i = 0; i < return_objects.size(); i++) { - auto return_object = reply->add_return_objects(); - ObjectID id = ObjectID::ForTaskReturn( - task_spec.TaskId(), /*index=*/i + 1, - /*transport_type=*/static_cast(TaskTransportType::DIRECT_ACTOR)); - return_object->set_object_id(id.Binary()); + for (size_t i = 0; i < return_objects.size(); i++) { + auto return_object = reply->add_return_objects(); + ObjectID id = ObjectID::ForTaskReturn( + task_spec.TaskId(), /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT)); + return_object->set_object_id(id.Binary()); - // The object is nullptr if it already existed in the object store. - const auto &result = return_objects[i]; - if (result == nullptr || result->GetData()->IsPlasmaBuffer()) { - return_object->set_in_plasma(true); - } else { - if (result->GetData() != nullptr) { - return_object->set_data(result->GetData()->Data(), - result->GetData()->Size()); - } - if (result->GetMetadata() != nullptr) { - return_object->set_metadata(result->GetMetadata()->Data(), - result->GetMetadata()->Size()); - } - } - } + // The object is nullptr if it already existed in the object store. 
+ const auto &result = return_objects[i]; + if (result == nullptr || result->GetData()->IsPlasmaBuffer()) { + return_object->set_in_plasma(true); + } else { + if (result->GetData() != nullptr) { + return_object->set_data(result->GetData()->Data(), + result->GetData()->Size()); + } + if (result->GetMetadata() != nullptr) { + return_object->set_metadata(result->GetMetadata()->Data(), + result->GetMetadata()->Size()); + } + } + } - send_reply_callback(status, nullptr, nullptr); - }, - [send_reply_callback]() { - send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr, nullptr); - }, - dependencies); + send_reply_callback(status, nullptr, nullptr); + }, + [send_reply_callback]() { + send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr, + nullptr); + }, + dependencies); } -void CoreWorkerDirectActorTaskReceiver::HandleDirectActorCallArgWaitComplete( +void CoreWorkerDirectTaskReceiver::HandleDirectActorCallArgWaitComplete( const rpc::DirectActorCallArgWaitCompleteRequest &request, rpc::DirectActorCallArgWaitCompleteReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/core_worker/transport/direct_actor_transport.h b/src/ray/core_worker/transport/direct_actor_transport.h index 6192a2eb1..076e63a6e 100644 --- a/src/ray/core_worker/transport/direct_actor_transport.h +++ b/src/ray/core_worker/transport/direct_actor_transport.h @@ -15,7 +15,7 @@ #include "ray/core_worker/store_provider/memory_store_provider.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/rpc/grpc_server.h" -#include "ray/rpc/worker/worker_client.h" +#include "ray/rpc/worker/core_worker_client.h" namespace ray { @@ -40,15 +40,14 @@ struct ActorStateData { // This class is thread-safe. 
class CoreWorkerDirectActorTaskSubmitter { public: - CoreWorkerDirectActorTaskSubmitter( - boost::asio::io_service &io_service, - std::unique_ptr store_provider); + CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager, + CoreWorkerMemoryStoreProvider store_provider); /// Submit a task to an actor for execution. /// /// \param[in] task The task spec to submit. /// \return Status::Invalid if the task is not yet supported. - Status SubmitTask(const TaskSpecification &task_spec); + Status SubmitTask(TaskSpecification task_spec); /// Handle an update about an actor. /// @@ -67,10 +66,9 @@ class CoreWorkerDirectActorTaskSubmitter { /// \param[in] task_id The ID of a task. /// \param[in] num_returns Number of return objects. /// \return Void. - void DirectActorAssignTask(rpc::WorkerTaskClient &client, - std::unique_ptr request, - const ActorID &actor_id, const TaskID &task_id, - int num_returns); + void PushActorTask(rpc::CoreWorkerClient &client, + std::unique_ptr request, + const ActorID &actor_id, const TaskID &task_id, int num_returns); /// Treat a task as failed. /// @@ -98,11 +96,8 @@ class CoreWorkerDirectActorTaskSubmitter { /// \return Whether this actor is alive. bool IsActorAlive(const ActorID &actor_id) const; - /// The IO event loop. - boost::asio::io_service &io_service_; - - /// The `ClientCallManager` object that is shared by all `DirectActorClient`s. - rpc::ClientCallManager client_call_manager_; + /// The shared `ClientCallManager` object. + rpc::ClientCallManager &client_call_manager_; /// Mutex to proect the various maps below. mutable std::mutex mutex_; @@ -115,18 +110,17 @@ class CoreWorkerDirectActorTaskSubmitter { /// /// TODO(zhijunfu): this will be moved into `actor_states_` later when we can /// subscribe updates for a specific actor. - std::unordered_map> rpc_clients_; + std::unordered_map> rpc_clients_; /// Map from actor id to the actor's pending requests. 
- std::unordered_map>> + std::unordered_map>> pending_requests_; /// Map from actor id to the tasks that are waiting for reply. std::unordered_map> waiting_reply_tasks_; /// The store provider. - std::unique_ptr in_memory_store_; + CoreWorkerMemoryStoreProvider in_memory_store_; friend class CoreWorkerTest; }; @@ -235,6 +229,9 @@ class SchedulingQueue { void Add(int64_t seq_no, int64_t client_processed_up_to, std::function accept_request, std::function reject_request, const std::vector &dependencies = {}) { + if (seq_no == -1) { + seq_no = next_seq_no_; // A value of -1 means no ordering constraint. + } RAY_CHECK(boost::this_thread::get_id() == main_thread_id_); if (client_processed_up_to >= next_seq_no_) { RAY_LOG(ERROR) << "client skipping requests " << next_seq_no_ << " to " @@ -329,29 +326,27 @@ class SchedulingQueue { friend class SchedulingQueueTest; }; -class CoreWorkerDirectActorTaskReceiver { +class CoreWorkerDirectTaskReceiver { public: using TaskHandler = std::function> *return_objects)>; - CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context, - boost::asio::io_service &main_io_service, - rpc::GrpcServer &server, - const TaskHandler &task_handler, - const std::function &exit_handler); + CoreWorkerDirectTaskReceiver(WorkerContext &worker_context, + boost::asio::io_service &main_io_service, + const TaskHandler &task_handler, + const std::function &exit_handler); /// Initialize this receiver. This must be called prior to use. void Init(RayletClient &client); - /// Handle a `DirectActorAssignTask` request. + /// Handle a `PushTask` request. /// /// \param[in] request The request message. /// \param[out] reply The reply message. /// \param[in] send_reply_callback The callback to be called when the request is done. 
- void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request, - rpc::DirectActorAssignTaskReply *reply, - rpc::SendReplyCallback send_reply_callback); + void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply, + rpc::SendReplyCallback send_reply_callback); /// Handle a `DirectActorCallArgWaitComplete` request. /// diff --git a/src/ray/core_worker/transport/direct_task_transport.cc b/src/ray/core_worker/transport/direct_task_transport.cc new file mode 100644 index 000000000..c05818e18 --- /dev/null +++ b/src/ray/core_worker/transport/direct_task_transport.cc @@ -0,0 +1,186 @@ +#include "ray/core_worker/transport/direct_task_transport.h" + +namespace ray { + +void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr value, + TaskSpecification &task) { + auto &msg = task.GetMutableMessage(); + bool found = false; + for (size_t i = 0; i < task.NumArgs(); i++) { + auto count = task.ArgIdCount(i); + if (count > 0) { + const auto &id = task.ArgId(i, 0); + if (id == obj_id) { + auto *mutable_arg = msg.mutable_args(i); + mutable_arg->clear_object_ids(); + if (value->HasData()) { + const auto &data = value->GetData(); + mutable_arg->set_data(data->Data(), data->Size()); + } + if (value->HasMetadata()) { + const auto &metadata = value->GetMetadata(); + mutable_arg->set_metadata(metadata->Data(), metadata->Size()); + } + found = true; + } + } + } + RAY_CHECK(found) << "obj id " << obj_id << " not found"; +} + +void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task, + std::function on_complete) { + absl::flat_hash_set local_dependencies; + for (size_t i = 0; i < task.NumArgs(); i++) { + auto count = task.ArgIdCount(i); + if (count > 0) { + RAY_CHECK(count <= 1) << "multi args not implemented"; + const auto &id = task.ArgId(i, 0); + if (id.IsDirectCallType()) { + local_dependencies.insert(id); + } + } + } + if (local_dependencies.empty()) { + on_complete(); + return; + } + + // This is deleted when 
the last dependency fetch callback finishes. + std::shared_ptr state = + std::shared_ptr(new TaskState{task, std::move(local_dependencies)}); + num_pending_ += 1; + + for (const auto &obj_id : state->local_dependencies) { + in_memory_store_.GetAsync( + obj_id, [this, state, obj_id, on_complete](std::shared_ptr obj) { + RAY_CHECK(obj != nullptr); + bool complete = false; + { + absl::MutexLock lock(&mu_); + state->local_dependencies.erase(obj_id); + DoInlineObjectValue(obj_id, obj, state->task); + if (state->local_dependencies.empty()) { + complete = true; + num_pending_ -= 1; + } + } + if (complete) { + on_complete(); + } + }); + } +} + +Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) { + resolver_.ResolveDependencies(task_spec, [this, task_spec]() { + // TODO(ekl) should have a queue per distinct resource type required + absl::MutexLock lock(&mu_); + RequestNewWorkerIfNeeded(task_spec); + queued_tasks_.push_back(task_spec); + // The task is now queued and will be picked up by the next leased or newly + // idle worker. We are guaranteed a worker will show up since we called + // RequestNewWorkerIfNeeded() earlier while holding mu_. + }); + return Status::OK(); +} + +void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(const WorkerAddress addr) { + // Setup client state for this worker. + { + absl::MutexLock lock(&mu_); + worker_request_pending_ = false; + + auto it = client_cache_.find(addr); + if (it == client_cache_.end()) { + client_cache_[addr] = + std::shared_ptr(client_factory_(addr)); + RAY_LOG(INFO) << "Connected to " << addr.first << ":" << addr.second; + } + } + + // Try to assign it work. 
+ OnWorkerIdle(addr, /*error=*/false); +} + +void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr, + bool was_error) { + absl::MutexLock lock(&mu_); + if (queued_tasks_.empty() || was_error) { + RAY_CHECK_OK(lease_client_.ReturnWorker(addr.second)); + } else { + auto &client = *client_cache_[addr]; + PushNormalTask(addr, client, queued_tasks_.front()); + queued_tasks_.pop_front(); + } + // We have a queue of tasks, try to request more workers. + if (!queued_tasks_.empty()) { + RequestNewWorkerIfNeeded(queued_tasks_.front()); + } +} + +void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded( + const TaskSpecification &resource_spec) { + if (worker_request_pending_) { + return; + } + RAY_CHECK_OK(lease_client_.RequestWorkerLease(resource_spec)); + worker_request_pending_ = true; +} + +void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id, + int num_returns, + const rpc::ErrorType &error_type) { + RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id + << ", error_type: " << ErrorType_Name(error_type); + for (int i = 0; i < num_returns; i++) { + const auto object_id = ObjectID::ForTaskReturn( + task_id, /*index=*/i + 1, + /*transport_type=*/static_cast(TaskTransportType::DIRECT)); + std::string meta = std::to_string(static_cast(error_type)); + auto metadata = const_cast(reinterpret_cast(meta.data())); + auto meta_buffer = std::make_shared(metadata, meta.size()); + RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id)); + } +} + +// TODO(ekl) consider reconsolidating with DirectActorTransport. 
+void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr, + rpc::CoreWorkerClientInterface &client, + TaskSpecification &task_spec) { + auto task_id = task_spec.TaskId(); + auto num_returns = task_spec.NumReturns(); + auto request = std::unique_ptr(new rpc::PushTaskRequest); + request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage()); + auto status = client.PushNormalTask( + std::move(request), + [this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) { + OnWorkerIdle(addr, /*was_error=*/!status.ok()); + if (!status.ok()) { + TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED); + return; + } + for (int i = 0; i < reply.return_objects_size(); i++) { + const auto &return_object = reply.return_objects(i); + ObjectID object_id = ObjectID::FromBinary(return_object.object_id()); + std::shared_ptr data_buffer; + if (return_object.data().size() > 0) { + data_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.data().data())), + return_object.data().size()); + } + std::shared_ptr metadata_buffer; + if (return_object.metadata().size() > 0) { + metadata_buffer = std::make_shared( + const_cast( + reinterpret_cast(return_object.metadata().data())), + return_object.metadata().size()); + } + RAY_CHECK_OK( + in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id)); + } + }); + RAY_CHECK_OK(status); +} +}; // namespace ray diff --git a/src/ray/core_worker/transport/direct_task_transport.h b/src/ray/core_worker/transport/direct_task_transport.h new file mode 100644 index 000000000..81547492d --- /dev/null +++ b/src/ray/core_worker/transport/direct_task_transport.h @@ -0,0 +1,128 @@ +#ifndef RAY_CORE_WORKER_DIRECT_TASK_H +#define RAY_CORE_WORKER_DIRECT_TASK_H + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "ray/common/id.h" +#include "ray/common/ray_object.h" +#include "ray/core_worker/context.h" +#include
"ray/core_worker/store_provider/memory_store_provider.h" +#include "ray/core_worker/transport/direct_actor_transport.h" +#include "ray/raylet/raylet_client.h" +#include "ray/rpc/worker/core_worker_client.h" + +namespace ray { + +struct TaskState { + /// The task to be run. + TaskSpecification task; + /// The remaining dependencies to resolve for this task. + absl::flat_hash_set local_dependencies; +}; + +// This class is thread-safe. +class LocalDependencyResolver { + public: + LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider) + : in_memory_store_(store_provider), num_pending_(0) {} + + /// Resolve all local and remote dependencies for the task, calling the specified + /// callback when done. Direct call ids in the task specification will be resolved + /// to concrete values and inlined. + /// + /// Note: This method **will mutate** the given TaskSpecification. + /// NOTE(review): `task` is a const reference; confirm how the mutation reaches it. + /// Postcondition: all direct call ids in arguments are converted to values. + void ResolveDependencies(const TaskSpecification &task, + std::function on_complete); + + /// Return the number of tasks pending dependency resolution. + /// TODO(ekl) this should be exposed in worker stats. + int NumPendingTasks() const { return num_pending_; } + + private: + /// The store provider. + CoreWorkerMemoryStoreProvider in_memory_store_; + + /// Number of tasks pending dependency resolution. + std::atomic num_pending_; + + /// Protects against concurrent access to internal state. + absl::Mutex mu_; +}; + +typedef std::pair WorkerAddress; +typedef std::function(WorkerAddress)> + ClientFactoryFn; + +// This class is thread-safe. 
+class CoreWorkerDirectTaskSubmitter { + public: + CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client, + ClientFactoryFn client_factory, + CoreWorkerMemoryStoreProvider store_provider) + : lease_client_(lease_client), + client_factory_(client_factory), + in_memory_store_(store_provider), + resolver_(in_memory_store_) {} + + /// Schedule a task for direct submission to a worker. + /// + /// \param[in] task_spec The task to schedule. + Status SubmitTask(TaskSpecification task_spec); + + /// Callback for when the raylet grants us a worker lease. The worker is returned + /// to the raylet once it finishes its task and either the lease term has + /// expired, or there is no more work it can take on. + /// + /// \param[in] addr The (addr, port) pair identifying the worker. + void HandleWorkerLeaseGranted(const WorkerAddress addr); + + private: + /// Schedule more work onto an idle worker or return it back to the raylet if + /// no more tasks are queued for submission. If an error was encountered + /// processing the worker, we don't attempt to re-use the worker. + void OnWorkerIdle(const WorkerAddress &addr, bool was_error); + + /// Request a new worker from the raylet if no such requests are currently in + /// flight. + void RequestNewWorkerIfNeeded(const TaskSpecification &resource_spec) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Push a task to a specific worker. + void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client, + TaskSpecification &task_spec); + + /// Mark a direct call as failed by storing errors for its return objects. + void TreatTaskAsFailed(const TaskID &task_id, int num_returns, + const rpc::ErrorType &error_type); + + // Client that can be used to lease and return workers. + WorkerLeaseInterface &lease_client_; + + /// Factory for producing new core worker clients. + ClientFactoryFn client_factory_; + + /// The store provider. 
+ CoreWorkerMemoryStoreProvider in_memory_store_; + + /// Resolves task dependencies before a task is queued for submission. + LocalDependencyResolver resolver_; + + /// Protects task submission state below. + absl::Mutex mu_; + + /// Cache of gRPC clients to other workers. + absl::flat_hash_map> + client_cache_ GUARDED_BY(mu_); + + /// Whether we have a request to the Raylet to acquire a new worker in flight. + bool worker_request_pending_ GUARDED_BY(mu_) = false; + + /// Tasks that are queued for execution in this submitter. + std::deque queued_tasks_ GUARDED_BY(mu_); +}; + +}; // namespace ray + +#endif // RAY_CORE_WORKER_DIRECT_TASK_H diff --git a/src/ray/core_worker/transport/raylet_transport.h b/src/ray/core_worker/transport/raylet_transport.h index e218e9cd9..3d05abe53 100644 --- a/src/ray/core_worker/transport/raylet_transport.h +++ b/src/ray/core_worker/transport/raylet_transport.h @@ -5,7 +5,7 @@ #include "ray/common/ray_object.h" #include "ray/raylet/raylet_client.h" -#include "ray/rpc/worker/worker_server.h" +#include "ray/rpc/worker/core_worker_server.h" namespace ray { diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index e23c40afd..c2439dfbc 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -66,6 +66,8 @@ message TaskSpec { // Task specification for an actor task. // This field is only valid when `type == ACTOR_TASK`. ActorTaskSpec actor_task_spec = 14; + // Whether this task is a direct call task. + bool is_direct_call = 15; } // Argument in the task. diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index d48f68eb6..2ad4b2623 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -56,22 +56,24 @@ message ReturnObject { bytes metadata = 4; } -message DirectActorAssignTaskRequest { +message PushTaskRequest { // The task to be pushed. TaskSpec task_spec = 1; // The sequence number of the task for this client. 
This must increase // sequentially starting from zero for each actor handle. The server // will guarantee tasks execute in this sequence, waiting for any // out-of-order request messages to arrive as necessary. + // If set to -1, ordering is disabled and the task executes immediately. + // This mode of behaviour is used for direct task submission only. int64 sequence_number = 2; // The max sequence number the client has processed responses for. This // is a performance optimization that allows the client to tell the server - // to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than + // to cancel any PushTaskRequests with seqno <= this value, rather than // waiting for the server to time out waiting for missing messages. int64 client_processed_up_to = 3; } -message DirectActorAssignTaskReply { +message PushTaskReply { // The returned objects. repeated ReturnObject return_objects = 1; } @@ -85,13 +87,24 @@ message DirectActorCallArgWaitCompleteRequest { message DirectActorCallArgWaitCompleteReply { } -service WorkerService { - // Push a task to a worker. +message WorkerLeaseGrantedRequest { + // Address of the leased worker. + string address = 1; + // Port of the leased worker. + int32 port = 2; +} + +message WorkerLeaseGrantedReply { +} + +service CoreWorkerService { + // Push a task to a worker from the raylet. rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply); - // Push a task to a direct-call actor. - rpc DirectActorAssignTask(DirectActorAssignTaskRequest) - returns (DirectActorAssignTaskReply); - // Notify wait for direct actor call args has completed. + // Push a task directly to this worker from another. + rpc PushTask(PushTaskRequest) returns (PushTaskReply); + // Reply from raylet that wait for direct actor call args has completed. rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest) returns (DirectActorCallArgWaitCompleteReply); + // Reply from raylet to fulfill a worker lease request. 
+ rpc WorkerLeaseGranted(WorkerLeaseGrantedRequest) returns (WorkerLeaseGrantedReply); } diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 92907079d..525628802 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -74,6 +74,10 @@ enum MessageType:int { SetResourceRequest, // Update the active set of object IDs in use on this worker. ReportActiveObjectIDs, + // Request a worker from the raylet with the specified resources. + RequestWorkerLease, + // Returns a worker to the raylet. + ReturnWorker, } table TaskExecutionSpecification { @@ -91,6 +95,14 @@ table Task { task_execution_spec: TaskExecutionSpecification; } +table WorkerLeaseRequest { + resource_spec: string; +} + +table ReturnWorkerRequest { + worker_port: int; +} + // This message describes a given resource that is reserved for a worker. table ResourceIdSetInfo { // The name of the resource. diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 633505c51..688bc1761 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -139,7 +139,8 @@ static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), - RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {}); + RandomTaskId(), 0, RandomTaskId(), num_returns, false, {}, + {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 5386bbe56..549f6d97b 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -2,6 +2,7 @@ #include #include +#include #include "ray/common/status.h" @@ -895,6 +896,12 @@ void NodeManager::ProcessClientMessage( // because it's already disconnected. 
return; } break; + case protocol::MessageType::RequestWorkerLease: { + ProcessRequestWorkerLeaseMessage(client, message_data); + } break; + case protocol::MessageType::ReturnWorker: { + ProcessReturnWorkerMessage(message_data); + } break; case protocol::MessageType::SetResourceRequest: { ProcessSetResourceRequest(client, message_data); } break; @@ -1031,6 +1038,10 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca void NodeManager::HandleWorkerAvailable( const std::shared_ptr &client) { std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + HandleWorkerAvailable(worker); +} + +void NodeManager::HandleWorkerAvailable(const std::shared_ptr &worker) { RAY_CHECK(worker); // If the worker was assigned a task, mark it as finished. if (!worker->GetAssignedTaskId().IsNil()) { @@ -1172,6 +1183,43 @@ void NodeManager::ProcessDisconnectClientMessage( // these can be leaked. } +void NodeManager::ProcessRequestWorkerLeaseMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + // Read the resource spec submitted by the client. + auto fbs_message = flatbuffers::GetRoot(message_data); + rpc::Task task_message; + RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( + fbs_message->resource_spec()->data(), fbs_message->resource_spec()->size())); + + // Override the task dispatch to call back to the client instead of executing the + // task directly on the worker. 
TODO(ekl) handle spilling case + Task task(task_message); + task.OnDispatchInstead([this, client](const std::shared_ptr granted, + const std::string &address, int port) { + std::shared_ptr client_worker = worker_pool_.GetRegisteredWorker(client); + if (client_worker == nullptr) { + client_worker = worker_pool_.GetRegisteredDriver(client); + } + if (client_worker == nullptr) { + RAY_LOG(FATAL) << "TODO: Lost worker for lease request " << client; + } else { + client_worker->WorkerLeaseGranted(address, port); + leased_workers_[port] = std::static_pointer_cast(granted); + } + }); + SubmitTask(task, Lineage()); +} + +void NodeManager::ProcessReturnWorkerMessage(const uint8_t *message_data) { + // Read the resource spec submitted by the client. + auto fbs_message = flatbuffers::GetRoot(message_data); + auto worker_port = fbs_message->worker_port(); + RAY_LOG(DEBUG) << "Return worker " << worker_port; + std::shared_ptr worker = leased_workers_[worker_port]; + leased_workers_.erase(worker_port); + HandleWorkerAvailable(worker); +} + void NodeManager::ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); @@ -1955,7 +2003,12 @@ bool NodeManager::AssignTask(const Task &task) { ResourceIdSet resource_id_set = worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds()); - worker->AssignTask(task, resource_id_set, finish_assign_task_callback); + if (task.OnDispatch() != nullptr) { + task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port()); + finish_assign_task_callback(Status::OK()); + } else { + worker->AssignTask(task, resource_id_set, finish_assign_task_callback); + } // We assigned this task to a worker. // (Note this means that we sent the task to the worker. 
The assignment diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index f44c55202..84f331552 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -394,6 +394,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void HandleWorkerAvailable(const std::shared_ptr &client); + /// Handle the case that a worker is available. + /// + /// \param worker The pointer to the worker + /// \return Void. + void HandleWorkerAvailable(const std::shared_ptr &worker); + /// Handle a client that has disconnected. This can be called multiple times /// on the same client because this is triggered both when a client /// disconnects and when the node manager fails to write a message to the @@ -406,6 +412,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler { const std::shared_ptr &client, bool intentional_disconnect = false); + /// Process client message of RequestWorkerLease + /// + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessRequestWorkerLeaseMessage( + const std::shared_ptr &client, const uint8_t *message_data); + + /// Process client message of ReturnWorkerMessage + /// + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessReturnWorkerMessage(const uint8_t *message_data); + /// Process client message of FetchOrReconstruct /// /// \param client The client that sent the message. @@ -574,12 +594,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler { rpc::NodeManagerGrpcService node_manager_service_; /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s - /// as well as all `WorkerTaskClient`s. + /// as well as all `CoreWorkerClient`s. rpc::ClientCallManager client_call_manager_; /// Map from node ids to clients of the remote node managers. 
std::unordered_map> remote_node_manager_clients_; + + /// Map of workers leased out to direct call clients. + std::unordered_map> leased_workers_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 6e1b34e14..429910e99 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -374,3 +374,19 @@ ray::Status RayletClient::ReportActiveObjectIDs( return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb); } + +ray::Status RayletClient::RequestWorkerLease( + const ray::TaskSpecification &resource_spec) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateWorkerLeaseRequest( + fbb, fbb.CreateString(resource_spec.Serialize())); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::RequestWorkerLease, &fbb); +} + +ray::Status RayletClient::ReturnWorker(int worker_port) { + flatbuffers::FlatBufferBuilder fbb; + auto message = ray::protocol::CreateReturnWorkerRequest(fbb, worker_port); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::ReturnWorker, &fbb); +} diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index c9dac0e42..38b46d56f 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -63,7 +63,21 @@ class RayletConnection { std::mutex write_mutex_; }; -class RayletClient { +/// Interface for leasing workers. Abstract for testing. +class WorkerLeaseInterface { + public: + /// Requests a worker from the raylet. The callback will be sent via gRPC. + /// \param resource_spec Resources that should be allocated for the worker. + /// \return ray::Status + virtual ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) = 0; + + /// Returns a worker to the raylet. + /// \param worker_port The local port of the worker on the raylet node. 
+ /// \return ray::Status + virtual ray::Status ReturnWorker(int worker_port) = 0; +}; + +class RayletClient : public WorkerLeaseInterface { public: /// Connect to the raylet. /// @@ -185,6 +199,12 @@ class RayletClient { /// \return ray::Status ray::Status ReportActiveObjectIDs(const std::unordered_set &object_ids); + /// Implements WorkerLeaseInterface. + ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) override; + + /// Implements WorkerLeaseInterface. + ray::Status ReturnWorker(int worker_port) override; + Language GetLanguage() const { return language_; } WorkerID GetWorkerID() const { return worker_id_; } diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index aae322454..2ae727434 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -76,7 +76,8 @@ static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), - RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {}); + RandomTaskId(), 0, RandomTaskId(), num_returns, false, {}, + {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index df6535098..78e6c4a04 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -25,8 +25,8 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i client_call_manager_(client_call_manager), is_detached_actor_(false) { if (port_ > 0) { - rpc_client_ = std::unique_ptr( - new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_)); + rpc_client_ = std::unique_ptr( + new rpc::CoreWorkerClient("127.0.0.1", port_, client_call_manager_)); } } @@ -172,6 +172,24 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) { } } +void Worker::WorkerLeaseGranted(const std::string 
&address, int port) { + RAY_CHECK(!address.empty()); + RAY_CHECK(port_ > 0); + rpc::WorkerLeaseGrantedRequest request; + request.set_address(address); + request.set_port(port); + auto status = rpc_client_->WorkerLeaseGranted( + request, [address, port](Status status, const rpc::WorkerLeaseGrantedReply &reply) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString() + << " for " << address << ":" << port; + } + }); + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString(); + } +} + } // namespace raylet } // end namespace ray diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index a67b3a76e..a162dc710 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -8,7 +8,7 @@ #include "ray/common/task/scheduling_resources.h" #include "ray/common/task/task.h" #include "ray/common/task/task_common.h" -#include "ray/rpc/worker/worker_client.h" +#include "ray/rpc/worker/core_worker_client.h" #include // pid_t @@ -67,6 +67,7 @@ class Worker { void AssignTask(const Task &task, const ResourceIdSet &resource_id_set, const std::function finish_assign_callback); void DirectActorCallArgWaitComplete(int64_t tag); + void WorkerLeaseGranted(const std::string &address, int port); private: /// The worker's ID. @@ -100,11 +101,11 @@ class Worker { std::unordered_set blocked_task_ids_; /// The set of object IDs that are currently in use on the worker. std::unordered_set active_object_ids_; - /// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all + /// The `ClientCallManager` object that is shared by `CoreWorkerClient` from all /// workers. rpc::ClientCallManager &client_call_manager_; /// The rpc client to send tasks to this worker. - std::unique_ptr rpc_client_; + std::unique_ptr rpc_client_; /// Whether the worker is detached. This is applies when the worker is actor. /// Detached actor means the actor's creator can exit without killing this actor. 
bool is_detached_actor_; diff --git a/src/ray/rpc/worker/worker_client.h b/src/ray/rpc/worker/core_worker_client.h similarity index 53% rename from src/ray/rpc/worker/worker_client.h rename to src/ray/rpc/worker/core_worker_client.h index 29d0bf3c2..4cd74af1f 100644 --- a/src/ray/rpc/worker/worker_client.h +++ b/src/ray/rpc/worker/core_worker_client.h @@ -1,5 +1,5 @@ -#ifndef RAY_RPC_WORKER_CLIENT_H -#define RAY_RPC_WORKER_CLIENT_H +#ifndef RAY_RPC_CORE_WORKER_CLIENT_H +#define RAY_RPC_CORE_WORKER_CLIENT_H #include #include @@ -25,7 +25,7 @@ const int64_t kMaxBytesInFlight = 16 * 1024 * 1024; const int64_t kBaseRequestSize = 1024; /// Get the estimated size in bytes of the given task. -const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) { +const static int64_t RequestSizeInBytes(const PushTaskRequest &request) { int64_t size = kBaseRequestSize; for (auto &arg : request.task_spec().args()) { size += arg.data().size(); @@ -33,44 +33,87 @@ const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &requ return size; } +/// Abstract client interface for testing. +class CoreWorkerClientInterface { + public: + /// This is called by the Raylet to assign a task to the worker. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// \return if the rpc call succeeds + virtual ray::Status AssignTask(const AssignTaskRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + + /// Push an actor task directly from worker to worker. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// \return if the rpc call succeeds + virtual ray::Status PushActorTask(std::unique_ptr request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + + /// Similar to PushActorTask, but sets no ordering constraint. 
This is used to + /// push non-actor tasks directly to a worker. + virtual ray::Status PushNormalTask(std::unique_ptr request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + + /// Notify a wait has completed for direct actor call arguments. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// \return if the rpc call succeeds + virtual ray::Status DirectActorCallArgWaitComplete( + const DirectActorCallArgWaitCompleteRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } + + /// Grants a worker to the client. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles reply. + /// \return if the rpc call succeeds + virtual ray::Status WorkerLeaseGranted( + const WorkerLeaseGrantedRequest &request, + const ClientCallback &callback) { + return Status::NotImplemented(""); + } +}; + /// Client used for communicating with a remote worker server. -class WorkerTaskClient : public std::enable_shared_from_this { +class CoreWorkerClient : public std::enable_shared_from_this, + public CoreWorkerClientInterface { public: /// Constructor. /// /// \param[in] address Address of the worker server. /// \param[in] port Port of the worker server. /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. - WorkerTaskClient(const std::string &address, const int port, + CoreWorkerClient(const std::string &address, const int port, ClientCallManager &client_call_manager) : client_call_manager_(client_call_manager) { std::shared_ptr channel = grpc::CreateChannel( address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); - stub_ = WorkerService::NewStub(channel); + stub_ = CoreWorkerService::NewStub(channel); }; - /// Assign a task to the work. - /// - /// \param[in] request The request message. 
- /// \param[in] callback The callback function that handles reply. - /// \return if the rpc call succeeds ray::Status AssignTask(const AssignTaskRequest &request, - const ClientCallback &callback) { - auto call = - client_call_manager_ - .CreateCall( - *stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback); + const ClientCallback &callback) override { + auto call = client_call_manager_ + .CreateCall( + *stub_, &CoreWorkerService::Stub::PrepareAsyncAssignTask, request, + callback); return call->GetStatus(); } - /// Push a task. - /// - /// \param[in] request The request message. - /// \param[in] callback The callback function that handles reply. - /// \return if the rpc call succeeds - ray::Status DirectActorAssignTask( - std::unique_ptr request, - const ClientCallback &callback) { + ray::Status PushActorTask(std::unique_ptr request, + const ClientCallback &callback) override { request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter()); { std::lock_guard lock(mutex_); @@ -85,20 +128,36 @@ class WorkerTaskClient : public std::enable_shared_from_this { return ray::Status::OK(); } - /// Notify a wait has completed for direct actor call arguments. - /// - /// \param[in] request The request message. - /// \param[in] callback The callback function that handles reply. 
- /// \return if the rpc call succeeds + ray::Status PushNormalTask(std::unique_ptr request, + const ClientCallback &callback) override { + request->set_sequence_number(-1); + request->set_client_processed_up_to(-1); + auto call = client_call_manager_ + .CreateCall( + *stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request, + callback); + return call->GetStatus(); + } + ray::Status DirectActorCallArgWaitComplete( const DirectActorCallArgWaitCompleteRequest &request, - const ClientCallback &callback) { + const ClientCallback &callback) override { + auto call = client_call_manager_.CreateCall( + *stub_, &CoreWorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete, + request, callback); + return call->GetStatus(); + } + + ray::Status WorkerLeaseGranted( + const WorkerLeaseGrantedRequest &request, + const ClientCallback &callback) override { auto call = - client_call_manager_ - .CreateCall( - *stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete, - request, callback); + client_call_manager_.CreateCall( + *stub_, &CoreWorkerService::Stub::PrepareAsyncWorkerLeaseGranted, request, + callback); return call->GetStatus(); } @@ -122,11 +181,10 @@ class WorkerTaskClient : public std::enable_shared_from_this { request->set_client_processed_up_to(max_finished_seq_no_); rpc_bytes_in_flight_ += task_size; - client_call_manager_.CreateCall( - *stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request, - [this, this_ptr, seq_no, task_size, callback]( - Status status, const rpc::DirectActorAssignTaskReply &reply) { + client_call_manager_.CreateCall( + *stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request, + [this, this_ptr, seq_no, task_size, callback](Status status, + const rpc::PushTaskReply &reply) { { std::lock_guard lock(mutex_); if (seq_no > max_finished_seq_no_) { @@ -150,14 +208,13 @@ class WorkerTaskClient : public std::enable_shared_from_this { std::mutex mutex_; /// The gRPC-generated stub. 
- std::unique_ptr stub_; + std::unique_ptr stub_; /// The `ClientCallManager` used for managing requests. ClientCallManager &client_call_manager_; /// Queue of requests to send. - std::deque, - ClientCallback>> + std::deque, ClientCallback>> send_queue_ GUARDED_BY(mutex_); /// The number of bytes currently in flight. @@ -174,4 +231,4 @@ class WorkerTaskClient : public std::enable_shared_from_this { } // namespace rpc } // namespace ray -#endif // RAY_RPC_WORKER_CLIENT_H +#endif // RAY_RPC_CORE_WORKER_CLIENT_H diff --git a/src/ray/rpc/worker/worker_server.cc b/src/ray/rpc/worker/core_worker_server.cc similarity index 58% rename from src/ray/rpc/worker/worker_server.cc rename to src/ray/rpc/worker/core_worker_server.cc index 44b3c7f7b..165c9031a 100644 --- a/src/ray/rpc/worker/worker_server.cc +++ b/src/ray/rpc/worker/core_worker_server.cc @@ -1,23 +1,23 @@ -#include "ray/rpc/worker/worker_server.h" +#include "ray/rpc/worker/core_worker_server.h" #include "ray/core_worker/core_worker.h" namespace ray { namespace rpc { -#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \ - std::unique_ptr HANDLER##_call_factory( \ - new ServerCallFactoryImpl( \ - service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \ - &CoreWorker::Handle##HANDLER, cq, main_service_)); \ - server_call_factories_and_concurrencies->emplace_back( \ +#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \ + std::unique_ptr HANDLER##_call_factory( \ + new ServerCallFactoryImpl( \ + service_, &CoreWorkerService::AsyncService::Request##HANDLER, core_worker_, \ + &CoreWorker::Handle##HANDLER, cq, main_service_)); \ + server_call_factories_and_concurrencies->emplace_back( \ std::move(HANDLER##_call_factory), CONCURRENCY); -WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service, - CoreWorker &core_worker) +CoreWorkerGrpcService::CoreWorkerGrpcService(boost::asio::io_service &main_service, + CoreWorker &core_worker) : GrpcService(main_service), 
core_worker_(core_worker){}; -void WorkerGrpcService::InitServerCallFactories( +void CoreWorkerGrpcService::InitServerCallFactories( const std::unique_ptr &cq, std::vector, int>> *server_call_factories_and_concurrencies) { diff --git a/src/ray/rpc/worker/worker_server.h b/src/ray/rpc/worker/core_worker_server.h similarity index 71% rename from src/ray/rpc/worker/worker_server.h rename to src/ray/rpc/worker/core_worker_server.h index d9f87126a..9e58929c5 100644 --- a/src/ray/rpc/worker/worker_server.h +++ b/src/ray/rpc/worker/core_worker_server.h @@ -1,5 +1,5 @@ -#ifndef RAY_RPC_WORKER_SERVER_H -#define RAY_RPC_WORKER_SERVER_H +#ifndef RAY_RPC_CORE_WORKER_SERVER_H +#define RAY_RPC_CORE_WORKER_SERVER_H #include "ray/rpc/grpc_server.h" #include "ray/rpc/server_call.h" @@ -13,14 +13,14 @@ class CoreWorker; namespace rpc { -/// The `GrpcServer` for `WorkerService`. -class WorkerGrpcService : public GrpcService { +/// The `GrpcServer` for `CoreWorkerService`. +class CoreWorkerGrpcService : public GrpcService { public: /// Constructor. /// /// \param[in] main_service See super class. /// \param[in] handler The service handler that actually handle the requests. - WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker); + CoreWorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker); protected: grpc::Service &GetGrpcService() override { return service_; } @@ -32,7 +32,7 @@ class WorkerGrpcService : public GrpcService { private: /// The grpc async service object. - WorkerService::AsyncService service_; + CoreWorkerService::AsyncService service_; /// The core worker that actually handles the requests. 
CoreWorker &core_worker_; @@ -41,4 +41,4 @@ class WorkerGrpcService : public GrpcService { } // namespace rpc } // namespace ray -#endif +#endif // RAY_RPC_CORE_WORKER_SERVER_H diff --git a/src/ray/util/test_util.h b/src/ray/util/test_util.h index a89acc7f1..79971bba9 100644 --- a/src/ray/util/test_util.h +++ b/src/ray/util/test_util.h @@ -3,6 +3,8 @@ #include +#include "ray/common/buffer.h" +#include "ray/common/ray_object.h" #include "ray/util/util.h" namespace ray { @@ -40,6 +42,20 @@ inline TaskID RandomTaskId() { return TaskID::FromBinary(data); } +std::shared_ptr GenerateRandomBuffer() { + auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count(); + std::mt19937 gen(seed); + std::uniform_int_distribution<> dis(1, 10); + std::uniform_int_distribution<> value_dis(1, 255); + + std::vector arg1(dis(gen), value_dis(gen)); + return std::make_shared(arg1.data(), arg1.size(), true); +} + +std::shared_ptr GenerateRandomObject() { + return std::shared_ptr(new RayObject(GenerateRandomBuffer(), nullptr)); +} + } // namespace ray #endif // RAY_UTIL_TEST_UTIL_H