mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:48:31 +08:00
Minimal implementation of direct task calls (#6075)
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
/bazel-*
|
||||
/python/ray/core
|
||||
/python/ray/pyarrow_files/
|
||||
/python/ray/pickle5_files/
|
||||
/python/build
|
||||
/python/dist
|
||||
/thirdparty/pkg/
|
||||
|
||||
+10
@@ -404,6 +404,16 @@ cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "direct_task_transport_test",
|
||||
srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":core_worker_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "reference_count_test",
|
||||
srcs = ["src/ray/core_worker/reference_count_test.cc"],
|
||||
|
||||
+8
-8
@@ -70,29 +70,29 @@ format_changed() {
|
||||
# could cause yapf to receive 0 positional arguments, making it hang
|
||||
# waiting for STDIN.
|
||||
#
|
||||
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
|
||||
# `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that
|
||||
# exist on both branches.
|
||||
MERGEBASE="$(git merge-base upstream/master HEAD)"
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
|
||||
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
|
||||
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
|
||||
if which flake8 >/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
|
||||
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
|
||||
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
|
||||
fi
|
||||
fi
|
||||
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
|
||||
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
|
||||
if which flake8 >/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
|
||||
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
|
||||
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
|
||||
fi
|
||||
fi
|
||||
|
||||
if which clang-format >/dev/null; then
|
||||
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
|
||||
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
|
||||
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
|
||||
clang-format -i
|
||||
fi
|
||||
fi
|
||||
|
||||
+32
-9
@@ -93,6 +93,7 @@ from ray.ray_constants import (
|
||||
DEFAULT_PUT_OBJECT_DELAY,
|
||||
DEFAULT_PUT_OBJECT_RETRIES,
|
||||
RAW_BUFFER_METADATA,
|
||||
PICKLE_BUFFER_METADATA,
|
||||
PICKLE5_BUFFER_METADATA,
|
||||
)
|
||||
|
||||
@@ -334,6 +335,7 @@ cdef c_vector[c_string] string_vector_from_list(list string_list):
|
||||
cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
|
||||
cdef:
|
||||
c_string pickled_str
|
||||
c_string metadata_str = PICKLE_BUFFER_METADATA
|
||||
shared_ptr[CBuffer] arg_data
|
||||
shared_ptr[CBuffer] arg_metadata
|
||||
|
||||
@@ -353,6 +355,11 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
|
||||
<uint8_t*>(pickled_str.data()),
|
||||
pickled_str.size(),
|
||||
True))
|
||||
arg_metadata = dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(
|
||||
metadata_str.data()), metadata_str.size(), True))
|
||||
args_vector.push_back(
|
||||
CTaskArg.PassByValue(
|
||||
make_shared[CRayObject](arg_data, arg_metadata)))
|
||||
@@ -436,8 +443,18 @@ cdef deserialize_args(
|
||||
c_args[i].get().GetMetadata()).to_pybytes()
|
||||
== RAW_BUFFER_METADATA):
|
||||
args.append(data)
|
||||
else:
|
||||
elif (c_args[i].get().HasMetadata() and Buffer.make(
|
||||
c_args[i].get().GetMetadata()).to_pybytes()
|
||||
== PICKLE_BUFFER_METADATA):
|
||||
# This is a pickled "simple python value" argument.
|
||||
args.append(pickle.loads(data.to_pybytes()))
|
||||
else:
|
||||
# This is a Ray object inlined by the direct task submitter.
|
||||
by_reference_ids.append(
|
||||
ObjectID(arg_reference_ids[i].Binary()))
|
||||
by_reference_indices.append(i)
|
||||
by_reference_objects.push_back(c_args[i])
|
||||
args.append(None)
|
||||
# Passed by reference.
|
||||
else:
|
||||
by_reference_ids.append(
|
||||
@@ -658,12 +675,14 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
|
||||
cdef shared_ptr[CBuffer] empty_metadata
|
||||
if c_str.size() == 0:
|
||||
return empty_metadata
|
||||
return dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](<uint8_t*>(c_str.data()),
|
||||
c_str.size(), True))
|
||||
return dynamic_pointer_cast[
|
||||
CBuffer, LocalMemoryBuffer](
|
||||
make_shared[LocalMemoryBuffer](
|
||||
<uint8_t*>(c_str.data()), c_str.size(), True))
|
||||
|
||||
|
||||
cdef write_serialized_object(serialized_object, const shared_ptr[CBuffer]& buf):
|
||||
cdef write_serialized_object(
|
||||
serialized_object, const shared_ptr[CBuffer]& buf):
|
||||
# avoid initializing pyarrow before raylet
|
||||
from ray.serialization import Pickle5SerializedObject, RawSerializedObject
|
||||
|
||||
@@ -851,6 +870,7 @@ cdef class CoreWorker:
|
||||
function_descriptor,
|
||||
args,
|
||||
int num_return_vals,
|
||||
c_bool is_direct_call,
|
||||
resources):
|
||||
cdef:
|
||||
unordered_map[c_string, double] c_resources
|
||||
@@ -861,7 +881,8 @@ cdef class CoreWorker:
|
||||
|
||||
with self.profile_event(b"submit_task"):
|
||||
prepare_resources(resources, &c_resources)
|
||||
task_options = CTaskOptions(num_return_vals, c_resources)
|
||||
task_options = CTaskOptions(
|
||||
num_return_vals, is_direct_call, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
|
||||
prepare_args(args, &args_vector)
|
||||
@@ -925,7 +946,7 @@ cdef class CoreWorker:
|
||||
with self.profile_event(b"submit_task"):
|
||||
if num_method_cpus > 0:
|
||||
c_resources[b"CPU"] = num_method_cpus
|
||||
task_options = CTaskOptions(num_return_vals, c_resources)
|
||||
task_options = CTaskOptions(num_return_vals, False, c_resources)
|
||||
ray_function = CRayFunction(
|
||||
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
|
||||
prepare_args(args, &args_vector)
|
||||
@@ -1017,7 +1038,8 @@ cdef class CoreWorker:
|
||||
context = worker.get_serialization_context()
|
||||
serialized_object = context.serialize(output)
|
||||
data_sizes.push_back(serialized_object.total_bytes)
|
||||
metadatas.push_back(string_to_buffer(serialized_object.metadata))
|
||||
metadatas.push_back(
|
||||
string_to_buffer(serialized_object.metadata))
|
||||
serialized_objects.append(serialized_object)
|
||||
|
||||
check_status(self.core_worker.get().AllocateReturnObjects(
|
||||
@@ -1030,4 +1052,5 @@ cdef class CoreWorker:
|
||||
if serialized_object is NoReturn:
|
||||
returns[0][i].reset()
|
||||
else:
|
||||
write_serialized_object(serialized_object, returns[0][i].get().GetData())
|
||||
write_serialized_object(
|
||||
serialized_object, returns[0][i].get().GetData())
|
||||
|
||||
@@ -168,7 +168,9 @@ class Dashboard(object):
|
||||
raise ValueError(
|
||||
"Dashboard static asset directory not found at '{}'. If "
|
||||
"installing from source, please follow the additional steps "
|
||||
"required to build the dashboard.".format(static_dir))
|
||||
"required to build the dashboard: "
|
||||
"cd python/ray/dashboard/client && npm ci && "
|
||||
"npm run build".format(static_dir))
|
||||
self.app.router.add_static("/static", static_dir)
|
||||
|
||||
self.app.router.add_get("/api/ray_config", ray_config)
|
||||
|
||||
@@ -201,7 +201,7 @@ cdef extern from "ray/core_worker/common.h" nogil:
|
||||
|
||||
cdef cppclass CTaskOptions "ray::TaskOptions":
|
||||
CTaskOptions()
|
||||
CTaskOptions(int num_returns,
|
||||
CTaskOptions(int num_returns, c_bool is_direct_call,
|
||||
unordered_map[c_string, double] &resources)
|
||||
|
||||
cdef cppclass CActorCreationOptions "ray::ActorCreationOptions":
|
||||
|
||||
@@ -148,7 +148,7 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
|
||||
c_bool is_put()
|
||||
|
||||
c_bool IsDirectActorType()
|
||||
c_bool IsDirectCallType()
|
||||
|
||||
int64_t ObjectIndex() const
|
||||
|
||||
|
||||
@@ -176,8 +176,8 @@ cdef class ObjectID(BaseID):
|
||||
def hex(self):
|
||||
return decode(self.data.Hex())
|
||||
|
||||
def is_direct_actor_type(self):
|
||||
return self.data.IsDirectActorType()
|
||||
def is_direct_call_type(self):
|
||||
return self.data.IsDirectCallType()
|
||||
|
||||
def is_nil(self):
|
||||
return self.data.IsNil()
|
||||
|
||||
@@ -179,6 +179,10 @@ LOG_MONITOR_MAX_OPEN_FILES = 200
|
||||
|
||||
# A constant used as object metadata to indicate the object is raw binary.
|
||||
RAW_BUFFER_METADATA = b"RAW"
|
||||
# A constant used as object metadata to indicate the object is pickled. This
|
||||
# format is only ever used for Python inline task argument values.
|
||||
PICKLE_BUFFER_METADATA = b"PICKLE"
|
||||
# A constant used as object metadata to indicate the object is pickle5 format.
|
||||
PICKLE5_BUFFER_METADATA = b"PICKLE5"
|
||||
|
||||
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
|
||||
|
||||
@@ -134,6 +134,7 @@ class RemoteFunction(object):
|
||||
args=None,
|
||||
kwargs=None,
|
||||
num_return_vals=None,
|
||||
is_direct_call=None,
|
||||
num_cpus=None,
|
||||
num_gpus=None,
|
||||
memory=None,
|
||||
@@ -155,6 +156,8 @@ class RemoteFunction(object):
|
||||
|
||||
if num_return_vals is None:
|
||||
num_return_vals = self._num_return_vals
|
||||
if is_direct_call is None:
|
||||
is_direct_call = False
|
||||
|
||||
resources = ray.utils.resources_from_resource_arguments(
|
||||
self._num_cpus, self._num_gpus, self._memory,
|
||||
@@ -162,8 +165,11 @@ class RemoteFunction(object):
|
||||
memory, object_store_memory, resources)
|
||||
|
||||
def invocation(args, kwargs):
|
||||
list_args = ray.signature.flatten_args(self._function_signature,
|
||||
args, kwargs)
|
||||
if not args and not kwargs and not self._function_signature:
|
||||
list_args = []
|
||||
else:
|
||||
list_args = ray.signature.flatten_args(
|
||||
self._function_signature, args, kwargs)
|
||||
|
||||
if worker.mode == ray.worker.LOCAL_MODE:
|
||||
object_ids = worker.local_mode_manager.execute(
|
||||
@@ -172,7 +178,7 @@ class RemoteFunction(object):
|
||||
else:
|
||||
object_ids = worker.core_worker.submit_task(
|
||||
self._function_descriptor_list, list_args, num_return_vals,
|
||||
resources)
|
||||
is_direct_call, resources)
|
||||
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
|
||||
@@ -157,8 +157,7 @@ class SerializationContext(object):
|
||||
serialization_context)
|
||||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj,
|
||||
ray.ObjectID) and obj.is_direct_actor_type():
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
@@ -191,8 +190,7 @@ class SerializationContext(object):
|
||||
custom_deserializer=actor_handle_deserializer)
|
||||
|
||||
def id_serializer(obj):
|
||||
if isinstance(obj,
|
||||
ray.ObjectID) and obj.is_direct_actor_type():
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
|
||||
@@ -7,7 +7,6 @@ import funcsigs
|
||||
from funcsigs import Parameter
|
||||
import logging
|
||||
|
||||
import ray
|
||||
from ray.utils import is_cython
|
||||
|
||||
# Logger for this module. It should be configured at the entry point
|
||||
@@ -136,12 +135,6 @@ def flatten_args(signature_parameters, args, kwargs):
|
||||
[None, 1, None, 2, None, 3, "a", 4]
|
||||
"""
|
||||
|
||||
for obj in args:
|
||||
if isinstance(obj, ray.ObjectID) and obj.is_direct_actor_type():
|
||||
raise NotImplementedError(
|
||||
"Objects produced by direct actor calls cannot be "
|
||||
"passed to other tasks as arguments.")
|
||||
|
||||
restored = _restore_parameters(signature_parameters)
|
||||
reconstructed_signature = funcsigs.Signature(parameters=restored)
|
||||
try:
|
||||
|
||||
@@ -1190,6 +1190,31 @@ def test_get_dict(ray_start_regular):
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_direct_call_simple(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 1
|
||||
|
||||
f_direct = f.options(is_direct_call=True)
|
||||
print("a")
|
||||
assert ray.get(f_direct.remote(2)) == 3
|
||||
print("b")
|
||||
assert ray.get([f_direct.remote(i) for i in range(100)]) == list(
|
||||
range(1, 101))
|
||||
|
||||
|
||||
def test_direct_call_chain(ray_start_regular):
|
||||
@ray.remote
|
||||
def g(x):
|
||||
return x + 1
|
||||
|
||||
g_direct = g.options(is_direct_call=True)
|
||||
x = 0
|
||||
for _ in range(100):
|
||||
x = g_direct.remote(x)
|
||||
assert ray.get(x) == 100
|
||||
|
||||
|
||||
def test_direct_actor_enabled(ray_start_regular):
|
||||
@ray.remote
|
||||
class Actor(object):
|
||||
@@ -1240,10 +1265,6 @@ def test_direct_actor_errors(ray_start_regular):
|
||||
|
||||
a = Actor._remote(is_direct_call=True)
|
||||
|
||||
# cannot pass returns to other methods directly
|
||||
with pytest.raises(Exception):
|
||||
ray.get(f.remote(a.f.remote(2)))
|
||||
|
||||
# cannot pass returns to other methods even in a list
|
||||
with pytest.raises(Exception):
|
||||
ray.get(f.remote([a.f.remote(2)]))
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <google/protobuf/map.h>
|
||||
#include <google/protobuf/repeated_field.h>
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <sstream>
|
||||
#include "status.h"
|
||||
|
||||
namespace ray {
|
||||
@@ -39,7 +40,7 @@ class MessageWrapper {
|
||||
const Message &GetMessage() const { return *message_; }
|
||||
|
||||
/// Get reference of the protobuf message.
|
||||
Message &GetMutableMessage() const { return *message_; }
|
||||
Message &GetMutableMessage() { return *message_; }
|
||||
|
||||
/// Serialize the message to a string.
|
||||
const std::string Serialize() const { return message_->SerializeAsString(); }
|
||||
@@ -64,7 +65,11 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
|
||||
if (grpc_status.ok()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status::IOError(grpc_status.error_message());
|
||||
std::stringstream msg;
|
||||
msg << grpc_status.error_code();
|
||||
msg << ": ";
|
||||
msg << grpc_status.error_message();
|
||||
return Status::IOError(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+3
-3
@@ -18,7 +18,7 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
enum class TaskTransportType { RAYLET, DIRECT_ACTOR };
|
||||
enum class TaskTransportType { RAYLET, DIRECT };
|
||||
|
||||
class TaskID;
|
||||
class WorkerID;
|
||||
@@ -290,8 +290,8 @@ class ObjectID : public BaseID<ObjectID> {
|
||||
/// Return if this is a direct actor call object.
|
||||
///
|
||||
/// \return True if this is a direct actor object return.
|
||||
bool IsDirectActorType() const {
|
||||
return GetTransportType() == static_cast<uint8_t>(TaskTransportType::DIRECT_ACTOR);
|
||||
bool IsDirectCallType() const {
|
||||
return GetTransportType() == static_cast<uint8_t>(TaskTransportType::DIRECT);
|
||||
}
|
||||
|
||||
/// Return this object id with a changed transport type.
|
||||
|
||||
@@ -9,6 +9,9 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
typedef std::function<void(const std::shared_ptr<void>, const std::string &, int)>
|
||||
DispatchTaskCallback;
|
||||
|
||||
/// \class Task
|
||||
///
|
||||
/// A Task represents a Ray task and a specification of its execution (e.g.,
|
||||
@@ -38,6 +41,13 @@ class Task {
|
||||
ComputeDependencies();
|
||||
}
|
||||
|
||||
/// Override dispatch behaviour.
|
||||
void OnDispatchInstead(
|
||||
std::function<void(const std::shared_ptr<void>, const std::string &, int)>
|
||||
callback) {
|
||||
on_dispatch_ = callback;
|
||||
}
|
||||
|
||||
/// Get the mutable specification for the task. This specification may be
|
||||
/// updated at runtime.
|
||||
///
|
||||
@@ -62,6 +72,9 @@ class Task {
|
||||
/// \param task Task structure with updated dynamic information.
|
||||
void CopyTaskExecutionSpec(const Task &task);
|
||||
|
||||
/// Returns the override dispatch task callback, or nullptr.
|
||||
DispatchTaskCallback &OnDispatch() const { return on_dispatch_; }
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
private:
|
||||
@@ -78,6 +91,10 @@ class Task {
|
||||
/// the TaskSpecification and execution dependencies from the
|
||||
/// TaskExecutionSpecification.
|
||||
std::vector<ObjectID> dependencies_;
|
||||
|
||||
/// For direct task calls, overrides the dispatch behaviour to send an RPC
|
||||
/// back to the submitting worker.
|
||||
mutable DispatchTaskCallback on_dispatch_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -187,8 +187,11 @@ ObjectID TaskSpecification::ActorDummyObject() const {
|
||||
}
|
||||
|
||||
bool TaskSpecification::IsDirectCall() const {
|
||||
RAY_CHECK(IsActorCreationTask());
|
||||
return message_->actor_creation_task_spec().is_direct_call();
|
||||
if (IsActorCreationTask()) {
|
||||
return message_->actor_creation_task_spec().is_direct_call();
|
||||
} else {
|
||||
return message_->is_direct_call();
|
||||
}
|
||||
}
|
||||
|
||||
int TaskSpecification::MaxActorConcurrency() const {
|
||||
|
||||
@@ -22,6 +22,8 @@ typedef std::pair<ResourceSet, FunctionDescriptor> SchedulingClassDescriptor;
|
||||
typedef int SchedulingClass;
|
||||
|
||||
/// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details.
|
||||
/// TODO(ekl) we should consider passing around std::unique_ptrs<TaskSpecification>
|
||||
/// instead `const TaskSpecification`, since this class is actually mutable.
|
||||
class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
|
||||
public:
|
||||
/// Construct an empty task specification. This should not be used directly.
|
||||
|
||||
@@ -27,7 +27,7 @@ class TaskSpecBuilder {
|
||||
const TaskID &task_id, const Language &language,
|
||||
const std::vector<std::string> &function_descriptor, const JobID &job_id,
|
||||
const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id,
|
||||
uint64_t num_returns,
|
||||
uint64_t num_returns, bool is_direct_call,
|
||||
const std::unordered_map<std::string, double> &required_resources,
|
||||
const std::unordered_map<std::string, double> &required_placement_resources) {
|
||||
message_->set_type(TaskType::NORMAL_TASK);
|
||||
@@ -41,6 +41,7 @@ class TaskSpecBuilder {
|
||||
message_->set_parent_counter(parent_counter);
|
||||
message_->set_caller_id(caller_id.Binary());
|
||||
message_->set_num_returns(num_returns);
|
||||
message_->set_is_direct_call(is_direct_call);
|
||||
message_->mutable_required_resources()->insert(required_resources.begin(),
|
||||
required_resources.end());
|
||||
message_->mutable_required_placement_resources()->insert(
|
||||
|
||||
@@ -84,11 +84,14 @@ class TaskArg {
|
||||
/// Options for all tasks (actor and non-actor) except for actor creation.
|
||||
struct TaskOptions {
|
||||
TaskOptions() {}
|
||||
TaskOptions(int num_returns, std::unordered_map<std::string, double> &resources)
|
||||
: num_returns(num_returns), resources(resources) {}
|
||||
TaskOptions(int num_returns, bool is_direct_call,
|
||||
std::unordered_map<std::string, double> &resources)
|
||||
: num_returns(num_returns), is_direct_call(is_direct_call), resources(resources) {}
|
||||
|
||||
/// Number of returns of this task.
|
||||
int num_returns = 1;
|
||||
/// Whether to use the direct task transport.
|
||||
bool is_direct_call = false;
|
||||
/// Resources required by this task.
|
||||
std::unordered_map<std::string, double> resources;
|
||||
};
|
||||
|
||||
@@ -89,12 +89,13 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
|
||||
if (task_spec.IsNormalTask()) {
|
||||
RAY_CHECK(current_job_id_.IsNil());
|
||||
SetCurrentJobId(task_spec.JobId());
|
||||
current_task_is_direct_call_ = task_spec.IsDirectCall();
|
||||
} else if (task_spec.IsActorCreationTask()) {
|
||||
RAY_CHECK(current_job_id_.IsNil());
|
||||
SetCurrentJobId(task_spec.JobId());
|
||||
RAY_CHECK(current_actor_id_.IsNil());
|
||||
current_actor_id_ = task_spec.ActorCreationId();
|
||||
current_actor_use_direct_call_ = task_spec.IsDirectCall();
|
||||
current_task_is_direct_call_ = task_spec.IsDirectCall();
|
||||
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
|
||||
} else if (task_spec.IsActorTask()) {
|
||||
RAY_CHECK(current_job_id_ == task_spec.JobId());
|
||||
@@ -117,8 +118,8 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
|
||||
|
||||
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
|
||||
|
||||
bool WorkerContext::CurrentActorUseDirectCall() const {
|
||||
return current_actor_use_direct_call_;
|
||||
bool WorkerContext::CurrentTaskIsDirectCall() const {
|
||||
return current_task_is_direct_call_;
|
||||
}
|
||||
|
||||
int WorkerContext::CurrentActorMaxConcurrency() const {
|
||||
|
||||
@@ -34,7 +34,7 @@ class WorkerContext {
|
||||
|
||||
const ActorID &GetCurrentActorID() const;
|
||||
|
||||
bool CurrentActorUseDirectCall() const;
|
||||
bool CurrentTaskIsDirectCall() const;
|
||||
|
||||
int CurrentActorMaxConcurrency() const;
|
||||
|
||||
@@ -47,8 +47,8 @@ class WorkerContext {
|
||||
const WorkerID worker_id_;
|
||||
JobID current_job_id_;
|
||||
ActorID current_actor_id_;
|
||||
bool current_actor_use_direct_call_ = false;
|
||||
int current_actor_max_concurrency_;
|
||||
bool current_task_is_direct_call_ = false;
|
||||
int current_actor_max_concurrency_ = 1;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext(bool for_main_thread = false);
|
||||
|
||||
@@ -19,11 +19,16 @@ void BuildCommonTaskSpec(
|
||||
// Build common task spec.
|
||||
builder.SetCommonTaskSpec(task_id, function.GetLanguage(),
|
||||
function.GetFunctionDescriptor(), job_id, current_task_id,
|
||||
task_index, caller_id, num_returns, required_resources,
|
||||
required_placement_resources);
|
||||
task_index, caller_id, num_returns,
|
||||
transport_type == ray::TaskTransportType::DIRECT,
|
||||
required_resources, required_placement_resources);
|
||||
// Set task arguments.
|
||||
for (const auto &arg : args) {
|
||||
if (arg.IsPassedByReference()) {
|
||||
if (transport_type == ray::TaskTransportType::RAYLET) {
|
||||
RAY_CHECK(!arg.GetReference().IsDirectCallType())
|
||||
<< "NotImplemented: passing direct call objects to other tasks";
|
||||
}
|
||||
builder.AddByRefArg(arg.GetReference());
|
||||
} else {
|
||||
builder.AddByValueArg(arg.GetValue());
|
||||
@@ -44,7 +49,7 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
|
||||
absl::flat_hash_set<ObjectID> *plasma_object_ids,
|
||||
absl::flat_hash_set<ObjectID> *memory_object_ids) {
|
||||
for (const auto &object_id : object_ids) {
|
||||
if (object_id.IsDirectActorType()) {
|
||||
if (object_id.IsDirectCallType()) {
|
||||
memory_object_ids->insert(object_id);
|
||||
} else {
|
||||
plasma_object_ids->insert(object_id);
|
||||
@@ -70,11 +75,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
check_signals_(check_signals),
|
||||
worker_context_(worker_type, job_id),
|
||||
io_work_(io_service_),
|
||||
client_call_manager_(new rpc::ClientCallManager(io_service_)),
|
||||
heartbeat_timer_(io_service_),
|
||||
worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
|
||||
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
|
||||
gcs_client_(gcs_options),
|
||||
client_call_manager_(io_service_),
|
||||
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
|
||||
memory_store_provider_(memory_store_),
|
||||
task_execution_service_work_(task_execution_service_),
|
||||
task_execution_callback_(task_execution_callback),
|
||||
grpc_service_(io_service_, *this) {
|
||||
@@ -95,24 +101,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
profiler_ = std::make_shared<worker::Profiler>(worker_context_, node_ip_address,
|
||||
io_service_, gcs_client_);
|
||||
|
||||
// Initialize task execution.
|
||||
// Initialize task receivers.
|
||||
if (worker_type_ == WorkerType::WORKER) {
|
||||
RAY_CHECK(task_execution_callback_ != nullptr);
|
||||
|
||||
// Initialize task receivers.
|
||||
auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3);
|
||||
raylet_task_receiver_ = std::unique_ptr<CoreWorkerRayletTaskReceiver>(
|
||||
new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler));
|
||||
direct_actor_task_receiver_ = std::unique_ptr<CoreWorkerDirectActorTaskReceiver>(
|
||||
new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_,
|
||||
worker_server_, execute_task,
|
||||
exit_handler));
|
||||
worker_server_.RegisterService(grpc_service_);
|
||||
direct_task_receiver_ =
|
||||
std::unique_ptr<CoreWorkerDirectTaskReceiver>(new CoreWorkerDirectTaskReceiver(
|
||||
worker_context_, task_execution_service_, execute_task, exit_handler));
|
||||
}
|
||||
|
||||
// Start RPC server after all the task receivers are properly initialized.
|
||||
worker_server_.Run();
|
||||
core_worker_server_.RegisterService(grpc_service_);
|
||||
core_worker_server_.Run();
|
||||
|
||||
// Initialize raylet client.
|
||||
// TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot
|
||||
@@ -120,15 +123,15 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
// so that the worker (java/python .etc) can retrieve and handle the error
|
||||
// instead of crashing.
|
||||
auto grpc_client = rpc::NodeManagerWorkerClient::make(
|
||||
node_ip_address, node_manager_port, client_call_manager_);
|
||||
node_ip_address, node_manager_port, *client_call_manager_);
|
||||
raylet_client_ = std::unique_ptr<RayletClient>(new RayletClient(
|
||||
std::move(grpc_client), raylet_socket,
|
||||
WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()),
|
||||
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
|
||||
language_, worker_server_.GetPort()));
|
||||
language_, core_worker_server_.GetPort()));
|
||||
// Unfortunately the raylet client has to be constructed after the receivers.
|
||||
if (direct_actor_task_receiver_ != nullptr) {
|
||||
direct_actor_task_receiver_->Init(*raylet_client_);
|
||||
if (direct_task_receiver_ != nullptr) {
|
||||
direct_task_receiver_->Init(*raylet_client_);
|
||||
}
|
||||
|
||||
// Set timer to periodically send heartbeats containing active object IDs to the raylet.
|
||||
@@ -149,7 +152,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
|
||||
plasma_store_provider_.reset(
|
||||
new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_));
|
||||
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
|
||||
|
||||
// Create an entry for the driver task in the task table. This task is
|
||||
// added immediately with status RUNNING. This allows us to push errors
|
||||
@@ -162,10 +164,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
std::vector<std::string> empty_descriptor;
|
||||
std::unordered_map<std::string, double> empty_resources;
|
||||
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
|
||||
builder.SetCommonTaskSpec(task_id, language_, empty_descriptor,
|
||||
worker_context_.GetCurrentJobID(),
|
||||
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
|
||||
0, GetCallerId(), 0, empty_resources, empty_resources);
|
||||
builder.SetCommonTaskSpec(
|
||||
task_id, language_, empty_descriptor, worker_context_.GetCurrentJobID(),
|
||||
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), 0, GetCallerId(), 0,
|
||||
false, empty_resources, empty_resources);
|
||||
|
||||
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
|
||||
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
|
||||
@@ -173,11 +175,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
|
||||
SetCurrentTaskId(task_id);
|
||||
}
|
||||
|
||||
// TODO(edoakes): why don't we just share the memory store provider?
|
||||
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
|
||||
new CoreWorkerDirectActorTaskSubmitter(
|
||||
io_service_, std::unique_ptr<CoreWorkerMemoryStoreProvider>(
|
||||
new CoreWorkerMemoryStoreProvider(memory_store_))));
|
||||
new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_,
|
||||
memory_store_provider_));
|
||||
|
||||
direct_task_submitter_ =
|
||||
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
|
||||
*raylet_client_,
|
||||
[this](WorkerAddress addr) {
|
||||
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(
|
||||
addr.first, addr.second, *client_call_manager_));
|
||||
},
|
||||
memory_store_provider_));
|
||||
}
|
||||
|
||||
CoreWorker::~CoreWorker() {
|
||||
@@ -308,9 +317,9 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
||||
local_timeout_ms = std::max(static_cast<int64_t>(0),
|
||||
timeout_ms - (current_time_ms() - start_time));
|
||||
}
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(),
|
||||
&result_map, &got_exception));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(),
|
||||
&result_map, &got_exception));
|
||||
}
|
||||
|
||||
// If any of the objects have been promoted to plasma, then we retry their
|
||||
@@ -334,9 +343,9 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
||||
for (const auto &id : promoted_plasma_ids) {
|
||||
auto it = result_map.find(id);
|
||||
if (it == result_map.end()) {
|
||||
result_map.erase(id.WithTransportType(TaskTransportType::DIRECT_ACTOR));
|
||||
result_map.erase(id.WithTransportType(TaskTransportType::DIRECT));
|
||||
} else {
|
||||
result_map[id.WithTransportType(TaskTransportType::DIRECT_ACTOR)] = it->second;
|
||||
result_map[id.WithTransportType(TaskTransportType::DIRECT)] = it->second;
|
||||
}
|
||||
result_map.erase(id);
|
||||
}
|
||||
@@ -360,10 +369,10 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
|
||||
|
||||
Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
|
||||
bool found = false;
|
||||
if (object_id.IsDirectActorType()) {
|
||||
if (object_id.IsDirectCallType()) {
|
||||
// Note that the memory store returns false if the object value is
|
||||
// ErrorType::OBJECT_IN_PLASMA.
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found));
|
||||
}
|
||||
if (!found) {
|
||||
// We check plasma as a fallback in all cases, since a direct call object
|
||||
@@ -413,7 +422,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
if (memory_object_ids.size() > 0) {
|
||||
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
|
||||
// consider waiting on them in plasma as well to ensure they are local.
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
|
||||
memory_object_ids, std::max(0, static_cast<int>(ready.size()) - num_objects),
|
||||
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
|
||||
}
|
||||
@@ -431,8 +440,8 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
|
||||
}
|
||||
if (memory_object_ids.size() > 0) {
|
||||
RAY_RETURN_NOT_OK(
|
||||
memory_store_provider_->Wait(memory_object_ids, num_objects, timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(), &ready));
|
||||
memory_store_provider_.Wait(memory_object_ids, num_objects, timeout_ms,
|
||||
worker_context_.GetCurrentTaskID(), &ready));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -453,7 +462,7 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
|
||||
|
||||
RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only,
|
||||
delete_creating_tasks));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids));
|
||||
RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
@@ -494,8 +503,7 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) {
|
||||
|
||||
if (task_deps->size() > 0) {
|
||||
for (size_t i = 0; i < num_returns; i++) {
|
||||
reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET),
|
||||
task_deps);
|
||||
reference_counter_.SetDependencies(task_spec.ReturnIdForPlasma(i), task_deps);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -511,11 +519,19 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
|
||||
const auto task_id =
|
||||
TaskID::ForNormalTask(worker_context_.GetCurrentJobID(),
|
||||
worker_context_.GetCurrentTaskID(), next_task_index);
|
||||
BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
|
||||
function, args, task_options.num_returns, task_options.resources,
|
||||
{}, TaskTransportType::RAYLET, return_ids);
|
||||
return SubmitTaskToRaylet(builder.Build());
|
||||
|
||||
// TODO(ekl) offload task building onto a thread pool for performance
|
||||
BuildCommonTaskSpec(
|
||||
builder, worker_context_.GetCurrentJobID(), task_id,
|
||||
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), function, args,
|
||||
task_options.num_returns, task_options.resources, {},
|
||||
task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET,
|
||||
return_ids);
|
||||
if (task_options.is_direct_call) {
|
||||
return direct_task_submitter_->SubmitTask(builder.Build());
|
||||
} else {
|
||||
return raylet_client_->SubmitTask(builder.Build());
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorker::CreateActor(const RayFunction &function,
|
||||
@@ -562,7 +578,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
|
||||
|
||||
const bool is_direct_call = actor_handle->IsDirectCallActor();
|
||||
const TaskTransportType transport_type =
|
||||
is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET;
|
||||
is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET;
|
||||
|
||||
// Build common task spec.
|
||||
TaskSpecBuilder builder;
|
||||
@@ -675,7 +691,7 @@ Status CoreWorker::AllocateReturnObjects(
|
||||
bool object_already_exists = false;
|
||||
std::shared_ptr<Buffer> data_buffer;
|
||||
if (data_sizes[i] > 0) {
|
||||
if (worker_context_.CurrentActorUseDirectCall() &&
|
||||
if (worker_context_.CurrentTaskIsDirectCall() &&
|
||||
static_cast<int64_t>(data_sizes[i]) <
|
||||
RayConfig::instance().max_direct_call_object_size()) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(data_sizes[i]);
|
||||
@@ -710,8 +726,8 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
||||
std::vector<ObjectID> arg_reference_ids;
|
||||
RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids));
|
||||
|
||||
const auto transport_type = worker_context_.CurrentActorUseDirectCall()
|
||||
? TaskTransportType::DIRECT_ACTOR
|
||||
const auto transport_type = worker_context_.CurrentTaskIsDirectCall()
|
||||
? TaskTransportType::DIRECT
|
||||
: TaskTransportType::RAYLET;
|
||||
std::vector<ObjectID> return_ids;
|
||||
for (size_t i = 0; i < task_spec.NumReturns(); i++) {
|
||||
@@ -745,7 +761,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
|
||||
RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to seal object "
|
||||
<< return_ids[i] << " in store: " << status.message();
|
||||
}
|
||||
} else if (!worker_context_.CurrentActorUseDirectCall()) {
|
||||
} else if (!worker_context_.CurrentTaskIsDirectCall()) {
|
||||
if (!Put(*return_objects->at(i), return_ids[i]).ok()) {
|
||||
RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to put object "
|
||||
<< return_ids[i] << " in store: " << status.message();
|
||||
@@ -816,7 +832,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
|
||||
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
if (worker_context_.CurrentActorUseDirectCall()) {
|
||||
if (worker_context_.CurrentTaskIsDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
|
||||
nullptr);
|
||||
return;
|
||||
@@ -827,12 +843,11 @@ void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorker::HandleDirectActorAssignTask(
|
||||
const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
|
||||
rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
task_execution_service_.post([=] {
|
||||
direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply,
|
||||
send_reply_callback);
|
||||
direct_task_receiver_->HandlePushTask(request, reply, send_reply_callback);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -841,9 +856,19 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete(
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
task_execution_service_.post([=] {
|
||||
direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete(
|
||||
request, reply, send_reply_callback);
|
||||
direct_task_receiver_->HandleDirectActorCallArgWaitComplete(request, reply,
|
||||
send_reply_callback);
|
||||
});
|
||||
}
|
||||
|
||||
void CoreWorker::HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request,
|
||||
rpc::WorkerLeaseGrantedReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
// Run this directly since the main thread may be tied up processing a task and
|
||||
// we need to still continue processing these scheduling operations in the backend.
|
||||
direct_task_submitter_->HandleWorkerLeaseGranted(
|
||||
std::make_pair(request.address(), request.port()));
|
||||
send_reply_callback(Status::OK(), nullptr, nullptr);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -14,22 +14,24 @@
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/store_provider/plasma_store_provider.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
#include "ray/core_worker/transport/raylet_transport.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/node_manager/node_manager_client.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_server.h"
|
||||
|
||||
/// The set of gRPC handlers and their associated level of concurrency. If you want to
|
||||
/// add a new call to the worker gRPC server, do the following:
|
||||
/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall"
|
||||
/// 1) Add the rpc to the CoreWorkerService in core_worker.proto, e.g., "ExampleCall"
|
||||
/// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)"
|
||||
/// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall"
|
||||
#define RAY_CORE_WORKER_RPC_HANDLERS \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100)
|
||||
#define RAY_CORE_WORKER_RPC_HANDLERS \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) \
|
||||
RAY_CORE_WORKER_RPC_HANDLER(WorkerLeaseGranted, 5)
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -319,29 +321,32 @@ class CoreWorker {
|
||||
const std::vector<std::shared_ptr<Buffer>> &metadatas,
|
||||
std::vector<std::shared_ptr<RayObject>> *return_objects);
|
||||
|
||||
/* Handlers for the worker's gRPC server. These are executed on the io_service_ and post
|
||||
* work to the appropriate event loop.
|
||||
/**
|
||||
* The following methods are handlers for the core worker's gRPC server, which follow
|
||||
* a macro-generated call convention. These are executed on the io_service_ and
|
||||
* post work to the appropriate event loop.
|
||||
*/
|
||||
|
||||
/// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task
|
||||
/// on this worker from the raylet.
|
||||
/// Implements gRPC server handler.
|
||||
void HandleAssignTask(const rpc::AssignTaskRequest &request,
|
||||
rpc::AssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task
|
||||
/// on this worker from another worker.
|
||||
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
/// Implements gRPC server handler.
|
||||
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a "DirectActorAssignTask" event corresponding to the raylet notifiying this
|
||||
/// worker that an argument is ready.
|
||||
/// Implements gRPC server handler.
|
||||
void HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Implements gRPC server handler.
|
||||
void HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request,
|
||||
rpc::WorkerLeaseGrantedReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
private:
|
||||
/// Run the io_service_ event loop. This should be called in a background thread.
|
||||
void RunIOService();
|
||||
@@ -446,19 +451,19 @@ class CoreWorker {
|
||||
/// Keeps the io_service_ alive.
|
||||
boost::asio::io_service::work io_work_;
|
||||
|
||||
/// Shared client call manager.
|
||||
std::unique_ptr<rpc::ClientCallManager> client_call_manager_;
|
||||
|
||||
/// Timer used to periodically send heartbeat containing active object IDs to the
|
||||
/// raylet.
|
||||
boost::asio::steady_timer heartbeat_timer_;
|
||||
|
||||
/// RPC server used to receive tasks to execute.
|
||||
rpc::GrpcServer worker_server_;
|
||||
rpc::GrpcServer core_worker_server_;
|
||||
|
||||
// Client to the GCS shared by core worker interfaces.
|
||||
gcs::RedisGcsClient gcs_client_;
|
||||
|
||||
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
|
||||
// Client to the raylet shared by core worker interfaces.
|
||||
std::unique_ptr<RayletClient> raylet_client_;
|
||||
|
||||
@@ -479,7 +484,7 @@ class CoreWorker {
|
||||
std::unique_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
|
||||
|
||||
/// In-memory store interface.
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
|
||||
CoreWorkerMemoryStoreProvider memory_store_provider_;
|
||||
|
||||
///
|
||||
/// Fields related to task submission.
|
||||
@@ -488,6 +493,9 @@ class CoreWorker {
|
||||
// Interface to submit tasks directly to other actors.
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
|
||||
|
||||
// Interface to submit non-actor tasks directly to leased workers.
|
||||
std::unique_ptr<CoreWorkerDirectTaskSubmitter> direct_task_submitter_;
|
||||
|
||||
/// Map from actor ID to a handle to that actor.
|
||||
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
|
||||
|
||||
@@ -519,10 +527,10 @@ class CoreWorker {
|
||||
std::unique_ptr<CoreWorkerRayletTaskReceiver> raylet_task_receiver_;
|
||||
|
||||
/// Common rpc service for all worker modules.
|
||||
rpc::WorkerGrpcService grpc_service_;
|
||||
rpc::CoreWorkerGrpcService grpc_service_;
|
||||
|
||||
// Interface that receives tasks from direct actor calls.
|
||||
std::unique_ptr<CoreWorkerDirectActorTaskReceiver> direct_actor_task_receiver_;
|
||||
std::unique_ptr<CoreWorkerDirectTaskReceiver> direct_task_receiver_;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
|
||||
@@ -106,32 +106,66 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
||||
|
||||
CoreWorkerMemoryStore::CoreWorkerMemoryStore() {}
|
||||
|
||||
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
return Status::ObjectExists("object already exists in the memory store");
|
||||
void CoreWorkerMemoryStore::GetAsync(
|
||||
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||
std::shared_ptr<RayObject> ptr;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
ptr = iter->second;
|
||||
} else {
|
||||
object_async_get_requests_[object_id].push_back(callback);
|
||||
}
|
||||
}
|
||||
// It's important for performance to run the callback outside the lock.
|
||||
if (ptr != nullptr) {
|
||||
callback(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
|
||||
std::vector<std::function<void(std::shared_ptr<RayObject>)>> async_callbacks;
|
||||
auto object_entry =
|
||||
std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), true);
|
||||
|
||||
bool should_add_entry = true;
|
||||
auto object_request_iter = object_get_requests_.find(object_id);
|
||||
if (object_request_iter != object_get_requests_.end()) {
|
||||
auto &get_requests = object_request_iter->second;
|
||||
for (auto &get_request : get_requests) {
|
||||
get_request->Set(object_id, object_entry);
|
||||
if (get_request->ShouldRemoveObjects()) {
|
||||
should_add_entry = false;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
return Status::ObjectExists("object already exists in the memory store");
|
||||
}
|
||||
|
||||
auto async_callback_it = object_async_get_requests_.find(object_id);
|
||||
if (async_callback_it != object_async_get_requests_.end()) {
|
||||
auto &callbacks = async_callback_it->second;
|
||||
async_callbacks = std::move(callbacks);
|
||||
object_async_get_requests_.erase(async_callback_it);
|
||||
}
|
||||
|
||||
bool should_add_entry = true;
|
||||
auto object_request_iter = object_get_requests_.find(object_id);
|
||||
if (object_request_iter != object_get_requests_.end()) {
|
||||
auto &get_requests = object_request_iter->second;
|
||||
for (auto &get_request : get_requests) {
|
||||
get_request->Set(object_id, object_entry);
|
||||
if (get_request->ShouldRemoveObjects()) {
|
||||
should_add_entry = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (should_add_entry) {
|
||||
// If there is no existing get request, then add the `RayObject` to map.
|
||||
objects_.emplace(object_id, object_entry);
|
||||
}
|
||||
}
|
||||
|
||||
if (should_add_entry) {
|
||||
// If there is no existing get request, then add the `RayObject` to map.
|
||||
objects_.emplace(object_id, object_entry);
|
||||
// It's important for performance to run the callbacks outside the lock.
|
||||
for (const auto &cb : async_callbacks) {
|
||||
cb(object_entry);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@@ -147,7 +181,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
absl::flat_hash_set<ObjectID> remaining_ids;
|
||||
absl::flat_hash_set<ObjectID> ids_to_remove;
|
||||
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
absl::MutexLock lock(&mu_);
|
||||
// Check for existing objects and see if this get request can be fullfilled.
|
||||
for (size_t i = 0; i < object_ids.size(); i++) {
|
||||
const auto &object_id = object_ids[i];
|
||||
@@ -192,7 +226,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
get_request->Wait(timeout_ms);
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
absl::MutexLock lock(&mu_);
|
||||
// Populate results.
|
||||
for (size_t i = 0; i < object_ids.size(); i++) {
|
||||
const auto &object_id = object_ids[i];
|
||||
@@ -223,14 +257,14 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
|
||||
}
|
||||
|
||||
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (const auto &object_id : object_ids) {
|
||||
objects_.erase(object_id);
|
||||
}
|
||||
}
|
||||
|
||||
bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id) {
|
||||
std::unique_lock<std::mutex> lock(lock_);
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = objects_.find(object_id);
|
||||
// If obj is in plasma, we defer to the plasma store for the Contains() call.
|
||||
return it != objects_.end() && !it->second->IsInPlasmaError();
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/core_worker/common.h"
|
||||
@@ -39,6 +40,15 @@ class CoreWorkerMemoryStore {
|
||||
Status Get(const std::vector<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
|
||||
bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results);
|
||||
|
||||
/// Asynchronously get an object from the object store. The object will not be removed
|
||||
/// from storage after GetAsync (TODO(ekl): integrate this with object GC).
|
||||
///
|
||||
/// \param[in] object_id The object id to get.
|
||||
/// \param[in] callback The callback to run with the reference to the retrieved
|
||||
/// object value once available.
|
||||
void GetAsync(const ObjectID &object_id,
|
||||
std::function<void(std::shared_ptr<RayObject>)> callback);
|
||||
|
||||
/// Delete a list of objects from the object store.
|
||||
///
|
||||
/// \param[in] object_ids IDs of the objects to delete.
|
||||
@@ -53,14 +63,19 @@ class CoreWorkerMemoryStore {
|
||||
|
||||
private:
|
||||
/// Map from object ID to `RayObject`.
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_;
|
||||
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_ GUARDED_BY(mu_);
|
||||
|
||||
/// Map from object ID to its get requests.
|
||||
absl::flat_hash_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
|
||||
object_get_requests_;
|
||||
object_get_requests_ GUARDED_BY(mu_);
|
||||
|
||||
/// Map from object ID to its async get requests.
|
||||
absl::flat_hash_map<ObjectID,
|
||||
std::vector<std::function<void(std::shared_ptr<RayObject>)>>>
|
||||
object_async_get_requests_ GUARDED_BY(mu_);
|
||||
|
||||
/// Protect the two maps above.
|
||||
std::mutex lock_;
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -19,6 +19,11 @@ class CoreWorkerMemoryStoreProvider {
|
||||
public:
|
||||
CoreWorkerMemoryStoreProvider(std::shared_ptr<CoreWorkerMemoryStore> store);
|
||||
|
||||
void GetAsync(const ObjectID &object_id,
|
||||
std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||
store_->GetAsync(object_id, callback);
|
||||
}
|
||||
|
||||
Status Put(const RayObject &object, const ObjectID &object_id);
|
||||
|
||||
Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
|
||||
|
||||
@@ -44,16 +44,6 @@ static void flushall_redis(void) {
|
||||
redisFree(context);
|
||||
}
|
||||
|
||||
std::shared_ptr<Buffer> GenerateRandomBuffer() {
|
||||
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
|
||||
std::mt19937 gen(seed);
|
||||
std::uniform_int_distribution<> dis(1, 10);
|
||||
std::uniform_int_distribution<> value_dis(1, 255);
|
||||
|
||||
std::vector<uint8_t> arg1(dis(gen), value_dis(gen));
|
||||
return std::make_shared<LocalMemoryBuffer>(arg1.data(), arg1.size(), true);
|
||||
}
|
||||
|
||||
ActorID CreateActorHelper(CoreWorker &worker,
|
||||
std::unordered_map<std::string, double> &resources,
|
||||
bool is_direct_call, uint64_t max_reconstructions) {
|
||||
@@ -279,16 +269,15 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func(ray::Language::PYTHON, {});
|
||||
|
||||
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
|
||||
ASSERT_EQ(return_ids.size(), 1);
|
||||
ASSERT_TRUE(return_ids[0].IsReturnObject());
|
||||
ASSERT_EQ(
|
||||
static_cast<TaskTransportType>(return_ids[0].GetTransportType()),
|
||||
is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET);
|
||||
ASSERT_EQ(static_cast<TaskTransportType>(return_ids[0].GetTransportType()),
|
||||
is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET);
|
||||
|
||||
std::vector<std::shared_ptr<ray::RayObject>> results;
|
||||
RAY_CHECK_OK(driver.Get(return_ids, -1, &results));
|
||||
@@ -320,7 +309,7 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func(ray::Language::PYTHON, {});
|
||||
auto status = driver.SubmitActorTask(actor_id, func, args, options, &return_ids);
|
||||
@@ -380,7 +369,7 @@ void CoreWorkerTest::TestActorReconstruction(
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func(ray::Language::PYTHON, {});
|
||||
|
||||
@@ -425,7 +414,7 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map<std::string, double> &r
|
||||
args.emplace_back(
|
||||
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func(ray::Language::PYTHON, {});
|
||||
|
||||
@@ -486,7 +475,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
|
||||
ASSERT_EQ(*data, *buffer);
|
||||
}
|
||||
|
||||
// Performance batchmark for `DirectActorAssignTaskRequest` creation.
|
||||
// Performance batchmark for `PushTaskRequest` creation.
|
||||
TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
// Create a dummy actor handle, and then create a number of `TaskSpec`
|
||||
// to benchmark performance.
|
||||
@@ -505,20 +494,21 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
function.GetFunctionDescriptor());
|
||||
|
||||
// Manually create `num_tasks` task specs, and for each of them create a
|
||||
// `DirectActorAssignTaskRequest`, this is to batch performance of TaskSpec
|
||||
// `PushTaskRequest`, this is to batch performance of TaskSpec
|
||||
// creation/copy/destruction.
|
||||
int64_t start_ms = current_time_ms();
|
||||
const auto num_tasks = 10000 * 10;
|
||||
RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests";
|
||||
RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests";
|
||||
for (int i = 0; i < num_tasks; i++) {
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
auto num_returns = options.num_returns;
|
||||
|
||||
TaskSpecBuilder builder;
|
||||
builder.SetCommonTaskSpec(RandomTaskId(), function.GetLanguage(),
|
||||
function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0,
|
||||
RandomTaskId(), num_returns, resources, resources);
|
||||
RandomTaskId(), num_returns, /*is_direct*/ false, resources,
|
||||
resources);
|
||||
// Set task arguments.
|
||||
for (const auto &arg : args) {
|
||||
if (arg.IsPassedByReference()) {
|
||||
@@ -531,14 +521,13 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
|
||||
actor_handle.SetActorTaskSpec(builder, TaskTransportType::RAYLET,
|
||||
ObjectID::FromRandom());
|
||||
|
||||
const auto &task_spec = builder.Build();
|
||||
auto task_spec = builder.Build();
|
||||
|
||||
ASSERT_TRUE(task_spec.IsActorTask());
|
||||
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
|
||||
new rpc::DirectActorAssignTaskRequest);
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
}
|
||||
RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests"
|
||||
RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests"
|
||||
<< ", which takes " << current_time_ms() - start_ms << " ms";
|
||||
}
|
||||
|
||||
@@ -566,7 +555,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
|
||||
sizeof(array));
|
||||
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
|
||||
|
||||
TaskOptions options{1, resources};
|
||||
TaskOptions options{1, false, resources};
|
||||
std::vector<ObjectID> return_ids;
|
||||
RayFunction func(ray::Language::PYTHON, {});
|
||||
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ray/common/task/task_spec.h"
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
#include "src/ray/util/test_util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
|
||||
public:
|
||||
ray::Status PushNormalTask(
|
||||
std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const rpc::ClientCallback<rpc::PushTaskReply> &callback) override {
|
||||
callbacks.push_back(callback);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
|
||||
};
|
||||
|
||||
class MockRayletClient : public WorkerLeaseInterface {
|
||||
public:
|
||||
ray::Status ReturnWorker(int worker_port) {
|
||||
num_workers_returned += 1;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) {
|
||||
num_workers_requested += 1;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int num_workers_requested = 0;
|
||||
int num_workers_returned = 0;
|
||||
};
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestNoDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
TaskSpecification task;
|
||||
bool ok = false;
|
||||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||
ASSERT_TRUE(ok);
|
||||
}
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
||||
bool ok = false;
|
||||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||
// We ignore and don't block on plasma dependencies.
|
||||
ASSERT_TRUE(ok);
|
||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||
}
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
auto data = GenerateRandomObject();
|
||||
// Ensure the data is already present in the local store.
|
||||
ASSERT_TRUE(store.Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store.Put(*data, obj2).ok());
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
|
||||
bool ok = false;
|
||||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||
// Tests that the task proto was rewritten to have inline argument values.
|
||||
ASSERT_TRUE(ok);
|
||||
ASSERT_FALSE(task.ArgByRef(0));
|
||||
ASSERT_FALSE(task.ArgByRef(1));
|
||||
ASSERT_NE(task.ArgData(0), nullptr);
|
||||
ASSERT_NE(task.ArgData(1), nullptr);
|
||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||
}
|
||||
|
||||
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
|
||||
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
CoreWorkerMemoryStoreProvider store(ptr);
|
||||
LocalDependencyResolver resolver(store);
|
||||
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
|
||||
auto data = GenerateRandomObject();
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
|
||||
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
|
||||
bool ok = false;
|
||||
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
|
||||
ASSERT_EQ(resolver.NumPendingTasks(), 1);
|
||||
ASSERT_TRUE(!ok);
|
||||
ASSERT_TRUE(store.Put(*data, obj1).ok());
|
||||
ASSERT_TRUE(store.Put(*data, obj2).ok());
|
||||
// Tests that the task proto was rewritten to have inline argument values after
|
||||
// resolution completes.
|
||||
ASSERT_TRUE(ok);
|
||||
ASSERT_FALSE(task.ArgByRef(0));
|
||||
ASSERT_FALSE(task.ArgByRef(1));
|
||||
ASSERT_NE(task.ArgData(0), nullptr);
|
||||
ASSERT_NE(task.ArgData(1), nullptr);
|
||||
ASSERT_EQ(resolver.NumPendingTasks(), 0);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestSubmitOneTask) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 0);
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 0);
|
||||
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
|
||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task;
|
||||
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task).ok());
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234));
|
||||
// Simulate a system failure, i.e., worker died unexpectedly.
|
||||
worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
TaskSpecification task3;
|
||||
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task3).ok());
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 1);
|
||||
|
||||
// Task 1 is pushed; worker 2 is requested.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 2);
|
||||
|
||||
// Task 2 is pushed; worker 3 is requested.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 3);
|
||||
|
||||
// Task 3 is pushed; no more workers requested.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1002));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 3);
|
||||
|
||||
// All workers returned.
|
||||
for (const auto &cb : worker_client->callbacks) {
|
||||
cb(Status::OK(), rpc::PushTaskReply());
|
||||
}
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 3);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
TaskSpecification task3;
|
||||
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task3).ok());
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 1);
|
||||
|
||||
// Task 1 is pushed.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 2);
|
||||
|
||||
// Task 1 finishes, Task 2 is scheduled on the same worker.
|
||||
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 0);
|
||||
|
||||
// Task 2 finishes, Task 3 is scheduled on the same worker.
|
||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 3);
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 0);
|
||||
|
||||
// Task 3 finishes, the worker is returned.
|
||||
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
|
||||
// The second lease request is returned immediately.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 2);
|
||||
}
|
||||
|
||||
TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) {
|
||||
MockRayletClient raylet_client;
|
||||
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
|
||||
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
|
||||
auto factory = [&](WorkerAddress addr) { return worker_client; };
|
||||
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
|
||||
TaskSpecification task1;
|
||||
TaskSpecification task2;
|
||||
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
|
||||
|
||||
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
|
||||
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 1);
|
||||
|
||||
// Task 1 is pushed.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_requested, 2);
|
||||
|
||||
// Task 1 finishes with failure; the worker is returned.
|
||||
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 1);
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 1);
|
||||
|
||||
// Task 2 runs successfully on the second worker.
|
||||
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
|
||||
ASSERT_EQ(worker_client->callbacks.size(), 2);
|
||||
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
|
||||
ASSERT_EQ(raylet_client.num_workers_returned, 2);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -6,14 +6,11 @@ using ray::rpc::ActorTableData;
|
||||
namespace ray {
|
||||
|
||||
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
|
||||
boost::asio::io_service &io_service,
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider)
|
||||
: io_service_(io_service),
|
||||
client_call_manager_(io_service),
|
||||
in_memory_store_(std::move(store_provider)) {}
|
||||
rpc::ClientCallManager &client_call_manager,
|
||||
CoreWorkerMemoryStoreProvider store_provider)
|
||||
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
|
||||
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
const TaskSpecification &task_spec) {
|
||||
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
|
||||
|
||||
RAY_CHECK(task_spec.IsActorTask());
|
||||
@@ -22,8 +19,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
const auto task_id = task_spec.TaskId();
|
||||
const auto num_returns = task_spec.NumReturns();
|
||||
|
||||
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
|
||||
new rpc::DirectActorAssignTaskRequest);
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
@@ -49,7 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
|
||||
|
||||
// Submit request.
|
||||
auto &client = rpc_clients_[actor_id];
|
||||
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
} else {
|
||||
// Actor is dead, treat the task as failure.
|
||||
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
|
||||
@@ -106,8 +102,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
const ActorID &actor_id, std::string ip_address, int port) {
|
||||
std::shared_ptr<rpc::WorkerTaskClient> grpc_client =
|
||||
std::make_shared<rpc::WorkerTaskClient>(ip_address, port, client_call_manager_);
|
||||
std::shared_ptr<rpc::CoreWorkerClient> grpc_client =
|
||||
std::make_shared<rpc::CoreWorkerClient>(ip_address, port, client_call_manager_);
|
||||
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
|
||||
|
||||
// Submit all pending requests.
|
||||
@@ -117,22 +113,20 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
|
||||
auto request = std::move(requests.front());
|
||||
auto num_returns = request->task_spec().num_returns();
|
||||
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
|
||||
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
|
||||
requests.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
|
||||
rpc::WorkerTaskClient &client,
|
||||
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request, const ActorID &actor_id,
|
||||
const TaskID &task_id, int num_returns) {
|
||||
void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
|
||||
rpc::CoreWorkerClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
|
||||
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
|
||||
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
|
||||
|
||||
auto status = client.DirectActorAssignTask(
|
||||
std::move(request),
|
||||
[this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::DirectActorAssignTaskReply &reply) {
|
||||
auto status = client.PushActorTask(
|
||||
std::move(request), [this, actor_id, task_id, num_returns](
|
||||
Status status, const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
std::unique_lock<std::mutex> guard(mutex_);
|
||||
waiting_reply_tasks_[actor_id].erase(task_id);
|
||||
@@ -156,7 +150,7 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
|
||||
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
} else {
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
@@ -172,8 +166,8 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(in_memory_store_->Put(RayObject(data_buffer, metadata_buffer),
|
||||
object_id));
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -189,11 +183,11 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,20 +198,19 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c
|
||||
return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE);
|
||||
}
|
||||
|
||||
CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver(
|
||||
CoreWorkerDirectTaskReceiver::CoreWorkerDirectTaskReceiver(
|
||||
WorkerContext &worker_context, boost::asio::io_service &main_io_service,
|
||||
rpc::GrpcServer &server, const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler)
|
||||
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
|
||||
: worker_context_(worker_context),
|
||||
task_handler_(task_handler),
|
||||
exit_handler_(exit_handler),
|
||||
task_main_io_service_(main_io_service) {}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) {
|
||||
void CoreWorkerDirectTaskReceiver::Init(RayletClient &raylet_client) {
|
||||
waiter_.reset(new DependencyWaiterImpl(raylet_client));
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
|
||||
void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
|
||||
if (max_concurrency != max_concurrency_) {
|
||||
RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
|
||||
RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime.";
|
||||
@@ -226,13 +219,13 @@ void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
|
||||
const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
|
||||
void CoreWorkerDirectTaskReceiver::HandlePushTask(
|
||||
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
|
||||
const TaskSpecification task_spec(request.task_spec());
|
||||
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
|
||||
if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) {
|
||||
if (task_spec.IsActorTask() && !worker_context_.CurrentTaskIsDirectCall()) {
|
||||
send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."),
|
||||
nullptr, nullptr);
|
||||
return;
|
||||
@@ -257,62 +250,63 @@ void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
|
||||
task_main_io_service_, *waiter_, pool_)));
|
||||
it = result.first;
|
||||
}
|
||||
it->second->Add(
|
||||
request.sequence_number(), request.client_processed_up_to(),
|
||||
[this, reply, send_reply_callback, task_spec]() {
|
||||
auto num_returns = task_spec.NumReturns();
|
||||
RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask());
|
||||
RAY_CHECK(num_returns > 0);
|
||||
// Decrease to account for the dummy object id.
|
||||
num_returns--;
|
||||
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
|
||||
[this, reply, send_reply_callback, task_spec]() {
|
||||
auto num_returns = task_spec.NumReturns();
|
||||
RAY_CHECK(num_returns > 0);
|
||||
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
|
||||
// Decrease to account for the dummy object id.
|
||||
num_returns--;
|
||||
}
|
||||
|
||||
// TODO(edoakes): resource IDs are currently kept track of in the raylet,
|
||||
// need to come up with a solution for this.
|
||||
ResourceMappingType resource_ids;
|
||||
std::vector<std::shared_ptr<RayObject>> return_objects;
|
||||
auto status = task_handler_(task_spec, resource_ids, &return_objects);
|
||||
if (status.IsSystemExit()) {
|
||||
// In Python, SystemExit can only be raised on the main thread. To work
|
||||
// around this when we are executing tasks on worker threads, we re-post the
|
||||
// exit event explicitly on the main thread.
|
||||
task_main_io_service_.post([this]() { exit_handler_(); });
|
||||
return;
|
||||
}
|
||||
RAY_CHECK(return_objects.size() == num_returns)
|
||||
<< return_objects.size() << " " << num_returns;
|
||||
// TODO(edoakes): resource IDs are currently kept track of in the
|
||||
// raylet, need to come up with a solution for this.
|
||||
ResourceMappingType resource_ids;
|
||||
std::vector<std::shared_ptr<RayObject>> return_objects;
|
||||
auto status = task_handler_(task_spec, resource_ids, &return_objects);
|
||||
if (status.IsSystemExit()) {
|
||||
// In Python, SystemExit can only be raised on the main thread. To
|
||||
// work around this when we are executing tasks on worker threads,
|
||||
// we re-post the exit event explicitly on the main thread.
|
||||
task_main_io_service_.post([this]() { exit_handler_(); });
|
||||
return;
|
||||
}
|
||||
RAY_CHECK(return_objects.size() == num_returns)
|
||||
<< return_objects.size() << " " << num_returns;
|
||||
|
||||
for (size_t i = 0; i < return_objects.size(); i++) {
|
||||
auto return_object = reply->add_return_objects();
|
||||
ObjectID id = ObjectID::ForTaskReturn(
|
||||
task_spec.TaskId(), /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
|
||||
return_object->set_object_id(id.Binary());
|
||||
for (size_t i = 0; i < return_objects.size(); i++) {
|
||||
auto return_object = reply->add_return_objects();
|
||||
ObjectID id = ObjectID::ForTaskReturn(
|
||||
task_spec.TaskId(), /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
return_object->set_object_id(id.Binary());
|
||||
|
||||
// The object is nullptr if it already existed in the object store.
|
||||
const auto &result = return_objects[i];
|
||||
if (result == nullptr || result->GetData()->IsPlasmaBuffer()) {
|
||||
return_object->set_in_plasma(true);
|
||||
} else {
|
||||
if (result->GetData() != nullptr) {
|
||||
return_object->set_data(result->GetData()->Data(),
|
||||
result->GetData()->Size());
|
||||
}
|
||||
if (result->GetMetadata() != nullptr) {
|
||||
return_object->set_metadata(result->GetMetadata()->Data(),
|
||||
result->GetMetadata()->Size());
|
||||
}
|
||||
}
|
||||
}
|
||||
// The object is nullptr if it already existed in the object store.
|
||||
const auto &result = return_objects[i];
|
||||
if (result == nullptr || result->GetData()->IsPlasmaBuffer()) {
|
||||
return_object->set_in_plasma(true);
|
||||
} else {
|
||||
if (result->GetData() != nullptr) {
|
||||
return_object->set_data(result->GetData()->Data(),
|
||||
result->GetData()->Size());
|
||||
}
|
||||
if (result->GetMetadata() != nullptr) {
|
||||
return_object->set_metadata(result->GetMetadata()->Data(),
|
||||
result->GetMetadata()->Size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
send_reply_callback(status, nullptr, nullptr);
|
||||
},
|
||||
[send_reply_callback]() {
|
||||
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr, nullptr);
|
||||
},
|
||||
dependencies);
|
||||
send_reply_callback(status, nullptr, nullptr);
|
||||
},
|
||||
[send_reply_callback]() {
|
||||
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr,
|
||||
nullptr);
|
||||
},
|
||||
dependencies);
|
||||
}
|
||||
|
||||
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorCallArgWaitComplete(
|
||||
void CoreWorkerDirectTaskReceiver::HandleDirectActorCallArgWaitComplete(
|
||||
const rpc::DirectActorCallArgWaitCompleteRequest &request,
|
||||
rpc::DirectActorCallArgWaitCompleteReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback) {
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/gcs/redis_gcs_client.h"
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
@@ -40,15 +40,14 @@ struct ActorStateData {
|
||||
// This class is thread-safe.
|
||||
class CoreWorkerDirectActorTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectActorTaskSubmitter(
|
||||
boost::asio::io_service &io_service,
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider);
|
||||
CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager,
|
||||
CoreWorkerMemoryStoreProvider store_provider);
|
||||
|
||||
/// Submit a task to an actor for execution.
|
||||
///
|
||||
/// \param[in] task The task spec to submit.
|
||||
/// \return Status::Invalid if the task is not yet supported.
|
||||
Status SubmitTask(const TaskSpecification &task_spec);
|
||||
Status SubmitTask(TaskSpecification task_spec);
|
||||
|
||||
/// Handle an update about an actor.
|
||||
///
|
||||
@@ -67,10 +66,9 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||
/// \param[in] task_id The ID of a task.
|
||||
/// \param[in] num_returns Number of return objects.
|
||||
/// \return Void.
|
||||
void DirectActorAssignTask(rpc::WorkerTaskClient &client,
|
||||
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id,
|
||||
int num_returns);
|
||||
void PushActorTask(rpc::CoreWorkerClient &client,
|
||||
std::unique_ptr<rpc::PushTaskRequest> request,
|
||||
const ActorID &actor_id, const TaskID &task_id, int num_returns);
|
||||
|
||||
/// Treat a task as failed.
|
||||
///
|
||||
@@ -98,11 +96,8 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||
/// \return Whether this actor is alive.
|
||||
bool IsActorAlive(const ActorID &actor_id) const;
|
||||
|
||||
/// The IO event loop.
|
||||
boost::asio::io_service &io_service_;
|
||||
|
||||
/// The `ClientCallManager` object that is shared by all `DirectActorClient`s.
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
/// The shared `ClientCallManager` object.
|
||||
rpc::ClientCallManager &client_call_manager_;
|
||||
|
||||
/// Mutex to proect the various maps below.
|
||||
mutable std::mutex mutex_;
|
||||
@@ -115,18 +110,17 @@ class CoreWorkerDirectActorTaskSubmitter {
|
||||
///
|
||||
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
|
||||
/// subscribe updates for a specific actor.
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::WorkerTaskClient>> rpc_clients_;
|
||||
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClient>> rpc_clients_;
|
||||
|
||||
/// Map from actor id to the actor's pending requests.
|
||||
std::unordered_map<ActorID,
|
||||
std::list<std::unique_ptr<rpc::DirectActorAssignTaskRequest>>>
|
||||
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
|
||||
pending_requests_;
|
||||
|
||||
/// Map from actor id to the tasks that are waiting for reply.
|
||||
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
|
||||
|
||||
/// The store provider.
|
||||
std::unique_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
|
||||
friend class CoreWorkerTest;
|
||||
};
|
||||
@@ -235,6 +229,9 @@ class SchedulingQueue {
|
||||
void Add(int64_t seq_no, int64_t client_processed_up_to,
|
||||
std::function<void()> accept_request, std::function<void()> reject_request,
|
||||
const std::vector<ObjectID> &dependencies = {}) {
|
||||
if (seq_no == -1) {
|
||||
seq_no = next_seq_no_; // A value of -1 means no ordering constraint.
|
||||
}
|
||||
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
|
||||
if (client_processed_up_to >= next_seq_no_) {
|
||||
RAY_LOG(ERROR) << "client skipping requests " << next_seq_no_ << " to "
|
||||
@@ -329,29 +326,27 @@ class SchedulingQueue {
|
||||
friend class SchedulingQueueTest;
|
||||
};
|
||||
|
||||
class CoreWorkerDirectActorTaskReceiver {
|
||||
class CoreWorkerDirectTaskReceiver {
|
||||
public:
|
||||
using TaskHandler = std::function<Status(
|
||||
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
|
||||
std::vector<std::shared_ptr<RayObject>> *return_objects)>;
|
||||
|
||||
CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context,
|
||||
boost::asio::io_service &main_io_service,
|
||||
rpc::GrpcServer &server,
|
||||
const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler);
|
||||
CoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
|
||||
boost::asio::io_service &main_io_service,
|
||||
const TaskHandler &task_handler,
|
||||
const std::function<void()> &exit_handler);
|
||||
|
||||
/// Initialize this receiver. This must be called prior to use.
|
||||
void Init(RayletClient &client);
|
||||
|
||||
/// Handle a `DirectActorAssignTask` request.
|
||||
/// Handle a `PushTask` request.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] send_reply_callback The callback to be called when the request is done.
|
||||
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
|
||||
rpc::DirectActorAssignTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
|
||||
rpc::SendReplyCallback send_reply_callback);
|
||||
|
||||
/// Handle a `DirectActorCallArgWaitComplete` request.
|
||||
///
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
#include "ray/core_worker/transport/direct_task_transport.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
|
||||
TaskSpecification &task) {
|
||||
auto &msg = task.GetMutableMessage();
|
||||
bool found = false;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id == obj_id) {
|
||||
auto *mutable_arg = msg.mutable_args(i);
|
||||
mutable_arg->clear_object_ids();
|
||||
if (value->HasData()) {
|
||||
const auto &data = value->GetData();
|
||||
mutable_arg->set_data(data->Data(), data->Size());
|
||||
}
|
||||
if (value->HasMetadata()) {
|
||||
const auto &metadata = value->GetMetadata();
|
||||
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
|
||||
}
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
RAY_CHECK(found) << "obj id " << obj_id << " not found";
|
||||
}
|
||||
|
||||
void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete) {
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
for (size_t i = 0; i < task.NumArgs(); i++) {
|
||||
auto count = task.ArgIdCount(i);
|
||||
if (count > 0) {
|
||||
RAY_CHECK(count <= 1) << "multi args not implemented";
|
||||
const auto &id = task.ArgId(i, 0);
|
||||
if (id.IsDirectCallType()) {
|
||||
local_dependencies.insert(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (local_dependencies.empty()) {
|
||||
on_complete();
|
||||
return;
|
||||
}
|
||||
|
||||
// This is deleted when the last dependency fetch callback finishes.
|
||||
std::shared_ptr<TaskState> state =
|
||||
std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
|
||||
num_pending_ += 1;
|
||||
|
||||
for (const auto &obj_id : state->local_dependencies) {
|
||||
in_memory_store_.GetAsync(
|
||||
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
|
||||
RAY_CHECK(obj != nullptr);
|
||||
bool complete = false;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
state->local_dependencies.erase(obj_id);
|
||||
DoInlineObjectValue(obj_id, obj, state->task);
|
||||
if (state->local_dependencies.empty()) {
|
||||
complete = true;
|
||||
num_pending_ -= 1;
|
||||
}
|
||||
}
|
||||
if (complete) {
|
||||
on_complete();
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
|
||||
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
|
||||
// TODO(ekl) should have a queue per distinct resource type required
|
||||
absl::MutexLock lock(&mu_);
|
||||
RequestNewWorkerIfNeeded(task_spec);
|
||||
queued_tasks_.push_back(task_spec);
|
||||
// The task is now queued and will be picked up by the next leased or newly
|
||||
// idle worker. We are guaranteed a worker will show up since we called
|
||||
// RequestNewWorkerIfNeeded() earlier while holding mu_.
|
||||
});
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(const WorkerAddress addr) {
|
||||
// Setup client state for this worker.
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
worker_request_pending_ = false;
|
||||
|
||||
auto it = client_cache_.find(addr);
|
||||
if (it == client_cache_.end()) {
|
||||
client_cache_[addr] =
|
||||
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr));
|
||||
RAY_LOG(INFO) << "Connected to " << addr.first << ":" << addr.second;
|
||||
}
|
||||
}
|
||||
|
||||
// Try to assign it work.
|
||||
OnWorkerIdle(addr, /*error=*/false);
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr,
|
||||
bool was_error) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
if (queued_tasks_.empty() || was_error) {
|
||||
RAY_CHECK_OK(lease_client_.ReturnWorker(addr.second));
|
||||
} else {
|
||||
auto &client = *client_cache_[addr];
|
||||
PushNormalTask(addr, client, queued_tasks_.front());
|
||||
queued_tasks_.pop_front();
|
||||
}
|
||||
// We have a queue of tasks, try to request more workers.
|
||||
if (!queued_tasks_.empty()) {
|
||||
RequestNewWorkerIfNeeded(queued_tasks_.front());
|
||||
}
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
|
||||
const TaskSpecification &resource_spec) {
|
||||
if (worker_request_pending_) {
|
||||
return;
|
||||
}
|
||||
RAY_CHECK_OK(lease_client_.RequestWorkerLease(resource_spec));
|
||||
worker_request_pending_ = true;
|
||||
}
|
||||
|
||||
void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id,
|
||||
int num_returns,
|
||||
const rpc::ErrorType &error_type) {
|
||||
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
|
||||
<< ", error_type: " << ErrorType_Name(error_type);
|
||||
for (int i = 0; i < num_returns; i++) {
|
||||
const auto object_id = ObjectID::ForTaskReturn(
|
||||
task_id, /*index=*/i + 1,
|
||||
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
|
||||
std::string meta = std::to_string(static_cast<int>(error_type));
|
||||
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
|
||||
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
|
||||
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(ekl) consider reconsolidating with DirectActorTransport.
|
||||
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
|
||||
rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec) {
|
||||
auto task_id = task_spec.TaskId();
|
||||
auto num_returns = task_spec.NumReturns();
|
||||
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
|
||||
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
|
||||
auto status = client.PushNormalTask(
|
||||
std::move(request),
|
||||
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
|
||||
OnWorkerIdle(addr, /*error=*/!status.ok());
|
||||
if (!status.ok()) {
|
||||
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED);
|
||||
return;
|
||||
}
|
||||
for (int i = 0; i < reply.return_objects_size(); i++) {
|
||||
const auto &return_object = reply.return_objects(i);
|
||||
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
|
||||
std::shared_ptr<LocalMemoryBuffer> data_buffer;
|
||||
if (return_object.data().size() > 0) {
|
||||
data_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.data().data())),
|
||||
return_object.data().size());
|
||||
}
|
||||
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
|
||||
if (return_object.metadata().size() > 0) {
|
||||
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
|
||||
const_cast<uint8_t *>(
|
||||
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
|
||||
return_object.metadata().size());
|
||||
}
|
||||
RAY_CHECK_OK(
|
||||
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
|
||||
}
|
||||
});
|
||||
RAY_CHECK_OK(status);
|
||||
}
|
||||
}; // namespace ray
|
||||
@@ -0,0 +1,128 @@
|
||||
#ifndef RAY_CORE_WORKER_DIRECT_TASK_H
|
||||
#define RAY_CORE_WORKER_DIRECT_TASK_H
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "ray/common/id.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/core_worker/context.h"
|
||||
#include "ray/core_worker/store_provider/memory_store_provider.h"
|
||||
#include "ray/core_worker/transport/direct_actor_transport.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
struct TaskState {
|
||||
/// The task to be run.
|
||||
TaskSpecification task;
|
||||
/// The remaining dependencies to resolve for this task.
|
||||
absl::flat_hash_set<ObjectID> local_dependencies;
|
||||
};
|
||||
|
||||
// This class is thread-safe.
|
||||
class LocalDependencyResolver {
|
||||
public:
|
||||
LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider)
|
||||
: in_memory_store_(store_provider), num_pending_(0) {}
|
||||
|
||||
/// Resolve all local and remote dependencies for the task, calling the specified
|
||||
/// callback when done. Direct call ids in the task specification will be resolved
|
||||
/// to concrete values and inlined.
|
||||
//
|
||||
/// Note: This method **will mutate** the given TaskSpecification.
|
||||
///
|
||||
/// Postcondition: all direct call ids in arguments are converted to values.
|
||||
void ResolveDependencies(const TaskSpecification &task,
|
||||
std::function<void()> on_complete);
|
||||
|
||||
/// Return the number of tasks pending dependency resolution.
|
||||
/// TODO(ekl) this should be exposed in worker stats.
|
||||
int NumPendingTasks() const { return num_pending_; }
|
||||
|
||||
private:
|
||||
/// The store provider.
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
|
||||
/// Number of tasks pending dependency resolution.
|
||||
std::atomic<int> num_pending_;
|
||||
|
||||
/// Protects against concurrent access to internal state.
|
||||
absl::Mutex mu_;
|
||||
};
|
||||
|
||||
typedef std::pair<std::string, int> WorkerAddress;
|
||||
typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddress)>
|
||||
ClientFactoryFn;
|
||||
|
||||
// This class is thread-safe.
|
||||
class CoreWorkerDirectTaskSubmitter {
|
||||
public:
|
||||
CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client,
|
||||
ClientFactoryFn client_factory,
|
||||
CoreWorkerMemoryStoreProvider store_provider)
|
||||
: lease_client_(lease_client),
|
||||
client_factory_(client_factory),
|
||||
in_memory_store_(store_provider),
|
||||
resolver_(in_memory_store_) {}
|
||||
|
||||
/// Schedule a task for direct submission to a worker.
|
||||
///
|
||||
/// \param[in] task_spec The task to schedule.
|
||||
Status SubmitTask(TaskSpecification task_spec);
|
||||
|
||||
/// Callback for when the raylet grants us a worker lease. The worker is returned
|
||||
/// to the raylet once it finishes its task and either the lease term has
|
||||
/// expired, or there is no more work it can take on.
|
||||
///
|
||||
/// \param[in] addr The (addr, port) pair identifying the worker.
|
||||
void HandleWorkerLeaseGranted(const WorkerAddress addr);
|
||||
|
||||
private:
|
||||
/// Schedule more work onto an idle worker or return it back to the raylet if
|
||||
/// no more tasks are queued for submission. If an error was encountered
|
||||
/// processing the worker, we don't attempt to re-use the worker.
|
||||
void OnWorkerIdle(const WorkerAddress &addr, bool was_error);
|
||||
|
||||
/// Request a new worker from the raylet if no such requests are currently in
|
||||
/// flight.
|
||||
void RequestNewWorkerIfNeeded(const TaskSpecification &resource_spec)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
/// Push a task to a specific worker.
|
||||
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
|
||||
TaskSpecification &task_spec);
|
||||
|
||||
/// Mark a direct call as failed by storing errors for its return objects.
|
||||
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
|
||||
const rpc::ErrorType &error_type);
|
||||
|
||||
// Client that can be used to lease and return workers.
|
||||
WorkerLeaseInterface &lease_client_;
|
||||
|
||||
/// Factory for producing new core worker clients.
|
||||
ClientFactoryFn client_factory_;
|
||||
|
||||
/// The store provider.
|
||||
CoreWorkerMemoryStoreProvider in_memory_store_;
|
||||
|
||||
/// Resolve local and remote dependencies;
|
||||
LocalDependencyResolver resolver_;
|
||||
|
||||
// Protects task submission state below.
|
||||
absl::Mutex mu_;
|
||||
|
||||
/// Cache of gRPC clients to other workers.
|
||||
absl::flat_hash_map<WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
|
||||
client_cache_ GUARDED_BY(mu_);
|
||||
|
||||
// Whether we have a request to the Raylet to acquire a new worker in flight.
|
||||
bool worker_request_pending_ GUARDED_BY(mu_) = false;
|
||||
|
||||
// Tasks that are queued for execution in this submitter..
|
||||
std::deque<TaskSpecification> queued_tasks_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
}; // namespace ray
|
||||
|
||||
#endif // RAY_CORE_WORKER_DIRECT_TASK_H
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/raylet/raylet_client.h"
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
#include "ray/rpc/worker/core_worker_server.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
|
||||
@@ -66,6 +66,8 @@ message TaskSpec {
|
||||
// Task specification for an actor task.
|
||||
// This field is only valid when `type == ACTOR_TASK`.
|
||||
ActorTaskSpec actor_task_spec = 14;
|
||||
// Whether this task is a direct call task.
|
||||
bool is_direct_call = 15;
|
||||
}
|
||||
|
||||
// Argument in the task.
|
||||
|
||||
@@ -56,22 +56,24 @@ message ReturnObject {
|
||||
bytes metadata = 4;
|
||||
}
|
||||
|
||||
message DirectActorAssignTaskRequest {
|
||||
message PushTaskRequest {
|
||||
// The task to be pushed.
|
||||
TaskSpec task_spec = 1;
|
||||
// The sequence number of the task for this client. This must increase
|
||||
// sequentially starting from zero for each actor handle. The server
|
||||
// will guarantee tasks execute in this sequence, waiting for any
|
||||
// out-of-order request messages to arrive as necessary.
|
||||
// If set to -1, ordering is disabled and the task executes immediately.
|
||||
// This mode of behaviour is used for direct task submission only.
|
||||
int64 sequence_number = 2;
|
||||
// The max sequence number the client has processed responses for. This
|
||||
// is a performance optimization that allows the client to tell the server
|
||||
// to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than
|
||||
// to cancel any PushTaskRequests with seqno <= this value, rather than
|
||||
// waiting for the server to time out waiting for missing messages.
|
||||
int64 client_processed_up_to = 3;
|
||||
}
|
||||
|
||||
message DirectActorAssignTaskReply {
|
||||
message PushTaskReply {
|
||||
// The returned objects.
|
||||
repeated ReturnObject return_objects = 1;
|
||||
}
|
||||
@@ -85,13 +87,24 @@ message DirectActorCallArgWaitCompleteRequest {
|
||||
message DirectActorCallArgWaitCompleteReply {
|
||||
}
|
||||
|
||||
service WorkerService {
|
||||
// Push a task to a worker.
|
||||
message WorkerLeaseGrantedRequest {
|
||||
// Address of the leased worker.
|
||||
string address = 1;
|
||||
// Port of the leased worker.
|
||||
int32 port = 2;
|
||||
}
|
||||
|
||||
message WorkerLeaseGrantedReply {
|
||||
}
|
||||
|
||||
service CoreWorkerService {
|
||||
// Push a task to a worker from the raylet.
|
||||
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
|
||||
// Push a task to a direct-call actor.
|
||||
rpc DirectActorAssignTask(DirectActorAssignTaskRequest)
|
||||
returns (DirectActorAssignTaskReply);
|
||||
// Notify wait for direct actor call args has completed.
|
||||
// Push a task directly to this worker from another.
|
||||
rpc PushTask(PushTaskRequest) returns (PushTaskReply);
|
||||
// Reply from raylet that wait for direct actor call args has completed.
|
||||
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
|
||||
returns (DirectActorCallArgWaitCompleteReply);
|
||||
// Reply from raylet to fulfill a worker lease request.
|
||||
rpc WorkerLeaseGranted(WorkerLeaseGrantedRequest) returns (WorkerLeaseGrantedReply);
|
||||
}
|
||||
|
||||
@@ -74,6 +74,10 @@ enum MessageType:int {
|
||||
SetResourceRequest,
|
||||
// Update the active set of object IDs in use on this worker.
|
||||
ReportActiveObjectIDs,
|
||||
// Request a worker from the raylet with the specified resources.
|
||||
RequestWorkerLease,
|
||||
// Returns a worker to the raylet.
|
||||
ReturnWorker,
|
||||
}
|
||||
|
||||
table TaskExecutionSpecification {
|
||||
@@ -91,6 +95,14 @@ table Task {
|
||||
task_execution_spec: TaskExecutionSpecification;
|
||||
}
|
||||
|
||||
table WorkerLeaseRequest {
|
||||
resource_spec: string;
|
||||
}
|
||||
|
||||
table ReturnWorkerRequest {
|
||||
worker_port: int;
|
||||
}
|
||||
|
||||
// This message describes a given resource that is reserved for a worker.
|
||||
table ResourceIdSetInfo {
|
||||
// The name of the resource.
|
||||
|
||||
@@ -139,7 +139,8 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
|
||||
uint64_t num_returns) {
|
||||
TaskSpecBuilder builder;
|
||||
builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(),
|
||||
RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {});
|
||||
RandomTaskId(), 0, RandomTaskId(), num_returns, false, {},
|
||||
{});
|
||||
for (const auto &arg : arguments) {
|
||||
builder.AddByRefArg(arg);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <cctype>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
|
||||
@@ -895,6 +896,12 @@ void NodeManager::ProcessClientMessage(
|
||||
// because it's already disconnected.
|
||||
return;
|
||||
} break;
|
||||
case protocol::MessageType::RequestWorkerLease: {
|
||||
ProcessRequestWorkerLeaseMessage(client, message_data);
|
||||
} break;
|
||||
case protocol::MessageType::ReturnWorker: {
|
||||
ProcessReturnWorkerMessage(message_data);
|
||||
} break;
|
||||
case protocol::MessageType::SetResourceRequest: {
|
||||
ProcessSetResourceRequest(client, message_data);
|
||||
} break;
|
||||
@@ -1031,6 +1038,10 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca
|
||||
void NodeManager::HandleWorkerAvailable(
|
||||
const std::shared_ptr<LocalClientConnection> &client) {
|
||||
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
|
||||
HandleWorkerAvailable(worker);
|
||||
}
|
||||
|
||||
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {
|
||||
RAY_CHECK(worker);
|
||||
// If the worker was assigned a task, mark it as finished.
|
||||
if (!worker->GetAssignedTaskId().IsNil()) {
|
||||
@@ -1172,6 +1183,43 @@ void NodeManager::ProcessDisconnectClientMessage(
|
||||
// these can be leaked.
|
||||
}
|
||||
|
||||
void NodeManager::ProcessRequestWorkerLeaseMessage(
|
||||
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
|
||||
// Read the resource spec submitted by the client.
|
||||
auto fbs_message = flatbuffers::GetRoot<protocol::WorkerLeaseRequest>(message_data);
|
||||
rpc::Task task_message;
|
||||
RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray(
|
||||
fbs_message->resource_spec()->data(), fbs_message->resource_spec()->size()));
|
||||
|
||||
// Override the task dispatch to call back to the client instead of executing the
|
||||
// task directly on the worker. TODO(ekl) handle spilling case
|
||||
Task task(task_message);
|
||||
task.OnDispatchInstead([this, client](const std::shared_ptr<void> granted,
|
||||
const std::string &address, int port) {
|
||||
std::shared_ptr<Worker> client_worker = worker_pool_.GetRegisteredWorker(client);
|
||||
if (client_worker == nullptr) {
|
||||
client_worker = worker_pool_.GetRegisteredDriver(client);
|
||||
}
|
||||
if (client_worker == nullptr) {
|
||||
RAY_LOG(FATAL) << "TODO: Lost worker for lease request " << client;
|
||||
} else {
|
||||
client_worker->WorkerLeaseGranted(address, port);
|
||||
leased_workers_[port] = std::static_pointer_cast<Worker>(granted);
|
||||
}
|
||||
});
|
||||
SubmitTask(task, Lineage());
|
||||
}
|
||||
|
||||
void NodeManager::ProcessReturnWorkerMessage(const uint8_t *message_data) {
|
||||
// Read the resource spec submitted by the client.
|
||||
auto fbs_message = flatbuffers::GetRoot<protocol::ReturnWorkerRequest>(message_data);
|
||||
auto worker_port = fbs_message->worker_port();
|
||||
RAY_LOG(DEBUG) << "Return worker " << worker_port;
|
||||
std::shared_ptr<Worker> worker = leased_workers_[worker_port];
|
||||
leased_workers_.erase(worker_port);
|
||||
HandleWorkerAvailable(worker);
|
||||
}
|
||||
|
||||
void NodeManager::ProcessFetchOrReconstructMessage(
|
||||
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
|
||||
auto message = flatbuffers::GetRoot<protocol::FetchOrReconstruct>(message_data);
|
||||
@@ -1955,7 +2003,12 @@ bool NodeManager::AssignTask(const Task &task) {
|
||||
ResourceIdSet resource_id_set =
|
||||
worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds());
|
||||
|
||||
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
|
||||
if (task.OnDispatch() != nullptr) {
|
||||
task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port());
|
||||
finish_assign_task_callback(Status::OK());
|
||||
} else {
|
||||
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
|
||||
}
|
||||
|
||||
// We assigned this task to a worker.
|
||||
// (Note this means that we sent the task to the worker. The assignment
|
||||
|
||||
@@ -394,6 +394,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
/// \return Void.
|
||||
void HandleWorkerAvailable(const std::shared_ptr<LocalClientConnection> &client);
|
||||
|
||||
/// Handle the case that a worker is available.
|
||||
///
|
||||
/// \param worker The pointer to the worker
|
||||
/// \return Void.
|
||||
void HandleWorkerAvailable(const std::shared_ptr<Worker> &worker);
|
||||
|
||||
/// Handle a client that has disconnected. This can be called multiple times
|
||||
/// on the same client because this is triggered both when a client
|
||||
/// disconnects and when the node manager fails to write a message to the
|
||||
@@ -406,6 +412,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
const std::shared_ptr<LocalClientConnection> &client,
|
||||
bool intentional_disconnect = false);
|
||||
|
||||
/// Process client message of RequestWorkerLease
|
||||
///
|
||||
/// \param client The client that sent the message.
|
||||
/// \param message_data A pointer to the message data.
|
||||
/// \return Void.
|
||||
void ProcessRequestWorkerLeaseMessage(
|
||||
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data);
|
||||
|
||||
/// Process client message of ReturnWorkerMessage
|
||||
///
|
||||
/// \param message_data A pointer to the message data.
|
||||
/// \return Void.
|
||||
void ProcessReturnWorkerMessage(const uint8_t *message_data);
|
||||
|
||||
/// Process client message of FetchOrReconstruct
|
||||
///
|
||||
/// \param client The client that sent the message.
|
||||
@@ -574,12 +594,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
rpc::NodeManagerGrpcService node_manager_service_;
|
||||
|
||||
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s
|
||||
/// as well as all `WorkerTaskClient`s.
|
||||
/// as well as all `CoreWorkerClient`s.
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
|
||||
/// Map from node ids to clients of the remote node managers.
|
||||
std::unordered_map<ClientID, std::unique_ptr<rpc::NodeManagerClient>>
|
||||
remote_node_manager_clients_;
|
||||
|
||||
/// Map of workers leased out to direct call clients.
|
||||
std::unordered_map<int, std::shared_ptr<Worker>> leased_workers_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -374,3 +374,19 @@ ray::Status RayletClient::ReportActiveObjectIDs(
|
||||
|
||||
return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::RequestWorkerLease(
|
||||
const ray::TaskSpecification &resource_spec) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateWorkerLeaseRequest(
|
||||
fbb, fbb.CreateString(resource_spec.Serialize()));
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::RequestWorkerLease, &fbb);
|
||||
}
|
||||
|
||||
ray::Status RayletClient::ReturnWorker(int worker_port) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateReturnWorkerRequest(fbb, worker_port);
|
||||
fbb.Finish(message);
|
||||
return conn_->WriteMessage(MessageType::ReturnWorker, &fbb);
|
||||
}
|
||||
|
||||
@@ -63,7 +63,21 @@ class RayletConnection {
|
||||
std::mutex write_mutex_;
|
||||
};
|
||||
|
||||
class RayletClient {
|
||||
/// Interface for leasing workers. Abstract for testing.
|
||||
class WorkerLeaseInterface {
|
||||
public:
|
||||
/// Requests a worker from the raylet. The callback will be sent via gRPC.
|
||||
/// \param resource_spec Resources that should be allocated for the worker.
|
||||
/// \return ray::Status
|
||||
virtual ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) = 0;
|
||||
|
||||
/// Returns a worker to the raylet.
|
||||
/// \param worker_port The local port of the worker on the raylet node.
|
||||
/// \return ray::Status
|
||||
virtual ray::Status ReturnWorker(int worker_port) = 0;
|
||||
};
|
||||
|
||||
class RayletClient : public WorkerLeaseInterface {
|
||||
public:
|
||||
/// Connect to the raylet.
|
||||
///
|
||||
@@ -185,6 +199,12 @@ class RayletClient {
|
||||
/// \return ray::Status
|
||||
ray::Status ReportActiveObjectIDs(const std::unordered_set<ObjectID> &object_ids);
|
||||
|
||||
/// Implements WorkerLeaseInterface.
|
||||
ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) override;
|
||||
|
||||
/// Implements WorkerLeaseInterface.
|
||||
ray::Status ReturnWorker(int worker_port) override;
|
||||
|
||||
Language GetLanguage() const { return language_; }
|
||||
|
||||
WorkerID GetWorkerID() const { return worker_id_; }
|
||||
|
||||
@@ -76,7 +76,8 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
|
||||
uint64_t num_returns) {
|
||||
TaskSpecBuilder builder;
|
||||
builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(),
|
||||
RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {});
|
||||
RandomTaskId(), 0, RandomTaskId(), num_returns, false, {},
|
||||
{});
|
||||
for (const auto &arg : arguments) {
|
||||
builder.AddByRefArg(arg);
|
||||
}
|
||||
|
||||
@@ -25,8 +25,8 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i
|
||||
client_call_manager_(client_call_manager),
|
||||
is_detached_actor_(false) {
|
||||
if (port_ > 0) {
|
||||
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
|
||||
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
|
||||
rpc_client_ = std::unique_ptr<rpc::CoreWorkerClient>(
|
||||
new rpc::CoreWorkerClient("127.0.0.1", port_, client_call_manager_));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,6 +172,24 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) {
|
||||
}
|
||||
}
|
||||
|
||||
void Worker::WorkerLeaseGranted(const std::string &address, int port) {
|
||||
RAY_CHECK(!address.empty());
|
||||
RAY_CHECK(port_ > 0);
|
||||
rpc::WorkerLeaseGrantedRequest request;
|
||||
request.set_address(address);
|
||||
request.set_port(port);
|
||||
auto status = rpc_client_->WorkerLeaseGranted(
|
||||
request, [address, port](Status status, const rpc::WorkerLeaseGrantedReply &reply) {
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString()
|
||||
<< " for " << address << ":" << port;
|
||||
}
|
||||
});
|
||||
if (!status.ok()) {
|
||||
RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
} // end namespace ray
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ray/common/task/scheduling_resources.h"
|
||||
#include "ray/common/task/task.h"
|
||||
#include "ray/common/task/task_common.h"
|
||||
#include "ray/rpc/worker/worker_client.h"
|
||||
#include "ray/rpc/worker/core_worker_client.h"
|
||||
|
||||
#include <unistd.h> // pid_t
|
||||
|
||||
@@ -67,6 +67,7 @@ class Worker {
|
||||
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
|
||||
const std::function<void(Status)> finish_assign_callback);
|
||||
void DirectActorCallArgWaitComplete(int64_t tag);
|
||||
void WorkerLeaseGranted(const std::string &address, int port);
|
||||
|
||||
private:
|
||||
/// The worker's ID.
|
||||
@@ -100,11 +101,11 @@ class Worker {
|
||||
std::unordered_set<TaskID> blocked_task_ids_;
|
||||
/// The set of object IDs that are currently in use on the worker.
|
||||
std::unordered_set<ObjectID> active_object_ids_;
|
||||
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
|
||||
/// The `ClientCallManager` object that is shared by `CoreWorkerClient` from all
|
||||
/// workers.
|
||||
rpc::ClientCallManager &client_call_manager_;
|
||||
/// The rpc client to send tasks to this worker.
|
||||
std::unique_ptr<rpc::WorkerTaskClient> rpc_client_;
|
||||
std::unique_ptr<rpc::CoreWorkerClient> rpc_client_;
|
||||
/// Whether the worker is detached. This is applies when the worker is actor.
|
||||
/// Detached actor means the actor's creator can exit without killing this actor.
|
||||
bool is_detached_actor_;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef RAY_RPC_WORKER_CLIENT_H
|
||||
#define RAY_RPC_WORKER_CLIENT_H
|
||||
#ifndef RAY_RPC_CORE_WORKER_CLIENT_H
|
||||
#define RAY_RPC_CORE_WORKER_CLIENT_H
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
@@ -25,7 +25,7 @@ const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
|
||||
const int64_t kBaseRequestSize = 1024;
|
||||
|
||||
/// Get the estimated size in bytes of the given task.
|
||||
const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) {
|
||||
const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
|
||||
int64_t size = kBaseRequestSize;
|
||||
for (auto &arg : request.task_spec().args()) {
|
||||
size += arg.data().size();
|
||||
@@ -33,44 +33,87 @@ const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &requ
|
||||
return size;
|
||||
}
|
||||
|
||||
/// Abstract client interface for testing.
|
||||
class CoreWorkerClientInterface {
|
||||
public:
|
||||
/// This is called by the Raylet to assign a task to the worker.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
virtual ray::Status AssignTask(const AssignTaskRequest &request,
|
||||
const ClientCallback<AssignTaskReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
/// Push an actor task directly from worker to worker.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
virtual ray::Status PushActorTask(std::unique_ptr<PushTaskRequest> request,
|
||||
const ClientCallback<PushTaskReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
/// Similar to PushActorTask, but sets no ordering constraint. This is used to
|
||||
/// push non-actor tasks directly to a worker.
|
||||
virtual ray::Status PushNormalTask(std::unique_ptr<PushTaskRequest> request,
|
||||
const ClientCallback<PushTaskReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
/// Notify a wait has completed for direct actor call arguments.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
virtual ray::Status DirectActorCallArgWaitComplete(
|
||||
const DirectActorCallArgWaitCompleteRequest &request,
|
||||
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
|
||||
/// Grants a worker to the client.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
virtual ray::Status WorkerLeaseGranted(
|
||||
const WorkerLeaseGrantedRequest &request,
|
||||
const ClientCallback<WorkerLeaseGrantedReply> &callback) {
|
||||
return Status::NotImplemented("");
|
||||
}
|
||||
};
|
||||
|
||||
/// Client used for communicating with a remote worker server.
|
||||
class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
|
||||
public CoreWorkerClientInterface {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] address Address of the worker server.
|
||||
/// \param[in] port Port of the worker server.
|
||||
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
|
||||
WorkerTaskClient(const std::string &address, const int port,
|
||||
CoreWorkerClient(const std::string &address, const int port,
|
||||
ClientCallManager &client_call_manager)
|
||||
: client_call_manager_(client_call_manager) {
|
||||
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
|
||||
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
stub_ = CoreWorkerService::NewStub(channel);
|
||||
};
|
||||
|
||||
/// Assign a task to the work.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status AssignTask(const AssignTaskRequest &request,
|
||||
const ClientCallback<AssignTaskReply> &callback) {
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<WorkerService, AssignTaskRequest, AssignTaskReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback);
|
||||
const ClientCallback<AssignTaskReply> &callback) override {
|
||||
auto call = client_call_manager_
|
||||
.CreateCall<CoreWorkerService, AssignTaskRequest, AssignTaskReply>(
|
||||
*stub_, &CoreWorkerService::Stub::PrepareAsyncAssignTask, request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
/// Push a task.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status DirectActorAssignTask(
|
||||
std::unique_ptr<DirectActorAssignTaskRequest> request,
|
||||
const ClientCallback<DirectActorAssignTaskReply> &callback) {
|
||||
ray::Status PushActorTask(std::unique_ptr<PushTaskRequest> request,
|
||||
const ClientCallback<PushTaskReply> &callback) override {
|
||||
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
@@ -85,20 +128,36 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
/// Notify a wait has completed for direct actor call arguments.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
/// \return if the rpc call succeeds
|
||||
ray::Status PushNormalTask(std::unique_ptr<PushTaskRequest> request,
|
||||
const ClientCallback<PushTaskReply> &callback) override {
|
||||
request->set_sequence_number(-1);
|
||||
request->set_client_processed_up_to(-1);
|
||||
auto call = client_call_manager_
|
||||
.CreateCall<CoreWorkerService, PushTaskRequest, PushTaskReply>(
|
||||
*stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status DirectActorCallArgWaitComplete(
|
||||
const DirectActorCallArgWaitCompleteRequest &request,
|
||||
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
|
||||
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) override {
|
||||
auto call = client_call_manager_.CreateCall<CoreWorkerService,
|
||||
DirectActorCallArgWaitCompleteRequest,
|
||||
DirectActorCallArgWaitCompleteReply>(
|
||||
*stub_, &CoreWorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
|
||||
request, callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
ray::Status WorkerLeaseGranted(
|
||||
const WorkerLeaseGrantedRequest &request,
|
||||
const ClientCallback<WorkerLeaseGrantedReply> &callback) override {
|
||||
auto call =
|
||||
client_call_manager_
|
||||
.CreateCall<WorkerService, DirectActorCallArgWaitCompleteRequest,
|
||||
DirectActorCallArgWaitCompleteReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
|
||||
request, callback);
|
||||
client_call_manager_.CreateCall<CoreWorkerService, WorkerLeaseGrantedRequest,
|
||||
WorkerLeaseGrantedReply>(
|
||||
*stub_, &CoreWorkerService::Stub::PrepareAsyncWorkerLeaseGranted, request,
|
||||
callback);
|
||||
return call->GetStatus();
|
||||
}
|
||||
|
||||
@@ -122,11 +181,10 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
request->set_client_processed_up_to(max_finished_seq_no_);
|
||||
rpc_bytes_in_flight_ += task_size;
|
||||
|
||||
client_call_manager_.CreateCall<WorkerService, DirectActorAssignTaskRequest,
|
||||
DirectActorAssignTaskReply>(
|
||||
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request,
|
||||
[this, this_ptr, seq_no, task_size, callback](
|
||||
Status status, const rpc::DirectActorAssignTaskReply &reply) {
|
||||
client_call_manager_.CreateCall<CoreWorkerService, PushTaskRequest, PushTaskReply>(
|
||||
*stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request,
|
||||
[this, this_ptr, seq_no, task_size, callback](Status status,
|
||||
const rpc::PushTaskReply &reply) {
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex_);
|
||||
if (seq_no > max_finished_seq_no_) {
|
||||
@@ -150,14 +208,13 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
std::mutex mutex_;
|
||||
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<WorkerService::Stub> stub_;
|
||||
std::unique_ptr<CoreWorkerService::Stub> stub_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager &client_call_manager_;
|
||||
|
||||
/// Queue of requests to send.
|
||||
std::deque<std::pair<std::unique_ptr<DirectActorAssignTaskRequest>,
|
||||
ClientCallback<DirectActorAssignTaskReply>>>
|
||||
std::deque<std::pair<std::unique_ptr<PushTaskRequest>, ClientCallback<PushTaskReply>>>
|
||||
send_queue_ GUARDED_BY(mutex_);
|
||||
|
||||
/// The number of bytes currently in flight.
|
||||
@@ -174,4 +231,4 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RPC_WORKER_CLIENT_H
|
||||
#endif // RAY_RPC_CORE_WORKER_CLIENT_H
|
||||
@@ -1,23 +1,23 @@
|
||||
#include "ray/rpc/worker/worker_server.h"
|
||||
#include "ray/rpc/worker/core_worker_server.h"
|
||||
#include "ray/core_worker/core_worker.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
|
||||
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
|
||||
new ServerCallFactoryImpl<WorkerService, CoreWorker, HANDLER##Request, \
|
||||
HANDLER##Reply>( \
|
||||
service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \
|
||||
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
|
||||
server_call_factories_and_concurrencies->emplace_back( \
|
||||
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
|
||||
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
|
||||
new ServerCallFactoryImpl<CoreWorkerService, CoreWorker, HANDLER##Request, \
|
||||
HANDLER##Reply>( \
|
||||
service_, &CoreWorkerService::AsyncService::Request##HANDLER, core_worker_, \
|
||||
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
|
||||
server_call_factories_and_concurrencies->emplace_back( \
|
||||
std::move(HANDLER##_call_factory), CONCURRENCY);
|
||||
|
||||
WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service,
|
||||
CoreWorker &core_worker)
|
||||
CoreWorkerGrpcService::CoreWorkerGrpcService(boost::asio::io_service &main_service,
|
||||
CoreWorker &core_worker)
|
||||
: GrpcService(main_service), core_worker_(core_worker){};
|
||||
|
||||
void WorkerGrpcService::InitServerCallFactories(
|
||||
void CoreWorkerGrpcService::InitServerCallFactories(
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) {
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef RAY_RPC_WORKER_SERVER_H
|
||||
#define RAY_RPC_WORKER_SERVER_H
|
||||
#ifndef RAY_RPC_CORE_WORKER_SERVER_H
|
||||
#define RAY_RPC_CORE_WORKER_SERVER_H
|
||||
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
@@ -13,14 +13,14 @@ class CoreWorker;
|
||||
|
||||
namespace rpc {
|
||||
|
||||
/// The `GrpcServer` for `WorkerService`.
|
||||
class WorkerGrpcService : public GrpcService {
|
||||
/// The `GrpcServer` for `CoreWorkerService`.
|
||||
class CoreWorkerGrpcService : public GrpcService {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] main_service See super class.
|
||||
/// \param[in] handler The service handler that actually handle the requests.
|
||||
WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
|
||||
CoreWorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
|
||||
|
||||
protected:
|
||||
grpc::Service &GetGrpcService() override { return service_; }
|
||||
@@ -32,7 +32,7 @@ class WorkerGrpcService : public GrpcService {
|
||||
|
||||
private:
|
||||
/// The grpc async service object.
|
||||
WorkerService::AsyncService service_;
|
||||
CoreWorkerService::AsyncService service_;
|
||||
|
||||
/// The core worker that actually handles the requests.
|
||||
CoreWorker &core_worker_;
|
||||
@@ -41,4 +41,4 @@ class WorkerGrpcService : public GrpcService {
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
#endif // RAY_RPC_CORE_WORKER_SERVER_H
|
||||
@@ -3,6 +3,8 @@
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ray/common/buffer.h"
|
||||
#include "ray/common/ray_object.h"
|
||||
#include "ray/util/util.h"
|
||||
|
||||
namespace ray {
|
||||
@@ -40,6 +42,20 @@ inline TaskID RandomTaskId() {
|
||||
return TaskID::FromBinary(data);
|
||||
}
|
||||
|
||||
std::shared_ptr<Buffer> GenerateRandomBuffer() {
|
||||
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
|
||||
std::mt19937 gen(seed);
|
||||
std::uniform_int_distribution<> dis(1, 10);
|
||||
std::uniform_int_distribution<> value_dis(1, 255);
|
||||
|
||||
std::vector<uint8_t> arg1(dis(gen), value_dis(gen));
|
||||
return std::make_shared<LocalMemoryBuffer>(arg1.data(), arg1.size(), true);
|
||||
}
|
||||
|
||||
std::shared_ptr<RayObject> GenerateRandomObject() {
|
||||
return std::shared_ptr<RayObject>(new RayObject(GenerateRandomBuffer(), nullptr));
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_UTIL_TEST_UTIL_H
|
||||
|
||||
Reference in New Issue
Block a user