Minimal implementation of direct task calls (#6075)

This commit is contained in:
Eric Liang
2019-11-12 11:45:28 -08:00
committed by GitHub
parent 35d177f459
commit f3f86385d6
49 changed files with 1358 additions and 384 deletions
+1
View File
@@ -2,6 +2,7 @@
/bazel-*
/python/ray/core
/python/ray/pyarrow_files/
/python/ray/pickle5_files/
/python/build
/python/dist
/thirdparty/pkg/
+10
View File
@@ -404,6 +404,16 @@ cc_binary(
],
)
cc_test(
name = "direct_task_transport_test",
srcs = ["src/ray/core_worker/test/direct_task_transport_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "reference_count_test",
srcs = ["src/ray/core_worker/reference_count_test.cc"],
+8 -8
View File
@@ -70,29 +70,29 @@ format_changed() {
# could cause yapf to receive 0 positional arguments, making it hang
# waiting for STDIN.
#
# `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
# `diff-filter=ACRM` and $MERGEBASE is to ensure we only format files that
# exist on both branches.
MERGEBASE="$(git merge-base upstream/master HEAD)"
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' | xargs -P 5 \
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E226,E24,E704,W503,W504,W605
fi
fi
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then
if which flake8 >/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \
flake8 --inline-quotes '"' --no-avoid-escape --exclude=python/ray/core/generated/,doc/source/conf.py,python/ray/cloudpickle/ --ignore=C408,E121,E123,E126,E211,E225,E226,E227,E24,E704,E999,W503,W504,W605
fi
fi
if which clang-format >/dev/null; then
if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.cc' '*.h' &>/dev/null; then
git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.cc' '*.h' | xargs -P 5 \
clang-format -i
fi
fi
+32 -9
View File
@@ -93,6 +93,7 @@ from ray.ray_constants import (
DEFAULT_PUT_OBJECT_DELAY,
DEFAULT_PUT_OBJECT_RETRIES,
RAW_BUFFER_METADATA,
PICKLE_BUFFER_METADATA,
PICKLE5_BUFFER_METADATA,
)
@@ -334,6 +335,7 @@ cdef c_vector[c_string] string_vector_from_list(list string_list):
cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
cdef:
c_string pickled_str
c_string metadata_str = PICKLE_BUFFER_METADATA
shared_ptr[CBuffer] arg_data
shared_ptr[CBuffer] arg_metadata
@@ -353,6 +355,11 @@ cdef void prepare_args(list args, c_vector[CTaskArg] *args_vector):
<uint8_t*>(pickled_str.data()),
pickled_str.size(),
True))
arg_metadata = dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(
metadata_str.data()), metadata_str.size(), True))
args_vector.push_back(
CTaskArg.PassByValue(
make_shared[CRayObject](arg_data, arg_metadata)))
@@ -436,8 +443,18 @@ cdef deserialize_args(
c_args[i].get().GetMetadata()).to_pybytes()
== RAW_BUFFER_METADATA):
args.append(data)
else:
elif (c_args[i].get().HasMetadata() and Buffer.make(
c_args[i].get().GetMetadata()).to_pybytes()
== PICKLE_BUFFER_METADATA):
# This is a pickled "simple python value" argument.
args.append(pickle.loads(data.to_pybytes()))
else:
# This is a Ray object inlined by the direct task submitter.
by_reference_ids.append(
ObjectID(arg_reference_ids[i].Binary()))
by_reference_indices.append(i)
by_reference_objects.push_back(c_args[i])
args.append(None)
# Passed by reference.
else:
by_reference_ids.append(
@@ -658,12 +675,14 @@ cdef shared_ptr[CBuffer] string_to_buffer(c_string& c_str):
cdef shared_ptr[CBuffer] empty_metadata
if c_str.size() == 0:
return empty_metadata
return dynamic_pointer_cast[CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](<uint8_t*>(c_str.data()),
c_str.size(), True))
return dynamic_pointer_cast[
CBuffer, LocalMemoryBuffer](
make_shared[LocalMemoryBuffer](
<uint8_t*>(c_str.data()), c_str.size(), True))
cdef write_serialized_object(serialized_object, const shared_ptr[CBuffer]& buf):
cdef write_serialized_object(
serialized_object, const shared_ptr[CBuffer]& buf):
# avoid initializing pyarrow before raylet
from ray.serialization import Pickle5SerializedObject, RawSerializedObject
@@ -851,6 +870,7 @@ cdef class CoreWorker:
function_descriptor,
args,
int num_return_vals,
c_bool is_direct_call,
resources):
cdef:
unordered_map[c_string, double] c_resources
@@ -861,7 +881,8 @@ cdef class CoreWorker:
with self.profile_event(b"submit_task"):
prepare_resources(resources, &c_resources)
task_options = CTaskOptions(num_return_vals, c_resources)
task_options = CTaskOptions(
num_return_vals, is_direct_call, c_resources)
ray_function = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
prepare_args(args, &args_vector)
@@ -925,7 +946,7 @@ cdef class CoreWorker:
with self.profile_event(b"submit_task"):
if num_method_cpus > 0:
c_resources[b"CPU"] = num_method_cpus
task_options = CTaskOptions(num_return_vals, c_resources)
task_options = CTaskOptions(num_return_vals, False, c_resources)
ray_function = CRayFunction(
LANGUAGE_PYTHON, string_vector_from_list(function_descriptor))
prepare_args(args, &args_vector)
@@ -1017,7 +1038,8 @@ cdef class CoreWorker:
context = worker.get_serialization_context()
serialized_object = context.serialize(output)
data_sizes.push_back(serialized_object.total_bytes)
metadatas.push_back(string_to_buffer(serialized_object.metadata))
metadatas.push_back(
string_to_buffer(serialized_object.metadata))
serialized_objects.append(serialized_object)
check_status(self.core_worker.get().AllocateReturnObjects(
@@ -1030,4 +1052,5 @@ cdef class CoreWorker:
if serialized_object is NoReturn:
returns[0][i].reset()
else:
write_serialized_object(serialized_object, returns[0][i].get().GetData())
write_serialized_object(
serialized_object, returns[0][i].get().GetData())
+3 -1
View File
@@ -168,7 +168,9 @@ class Dashboard(object):
raise ValueError(
"Dashboard static asset directory not found at '{}'. If "
"installing from source, please follow the additional steps "
"required to build the dashboard.".format(static_dir))
"required to build the dashboard: "
"cd python/ray/dashboard/client && npm ci && "
"npm run build".format(static_dir))
self.app.router.add_static("/static", static_dir)
self.app.router.add_get("/api/ray_config", ray_config)
+1 -1
View File
@@ -201,7 +201,7 @@ cdef extern from "ray/core_worker/common.h" nogil:
cdef cppclass CTaskOptions "ray::TaskOptions":
CTaskOptions()
CTaskOptions(int num_returns,
CTaskOptions(int num_returns, c_bool is_direct_call,
unordered_map[c_string, double] &resources)
cdef cppclass CActorCreationOptions "ray::ActorCreationOptions":
+1 -1
View File
@@ -148,7 +148,7 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
c_bool is_put()
c_bool IsDirectActorType()
c_bool IsDirectCallType()
int64_t ObjectIndex() const
+2 -2
View File
@@ -176,8 +176,8 @@ cdef class ObjectID(BaseID):
def hex(self):
return decode(self.data.Hex())
def is_direct_actor_type(self):
return self.data.IsDirectActorType()
def is_direct_call_type(self):
return self.data.IsDirectCallType()
def is_nil(self):
return self.data.IsNil()
+4
View File
@@ -179,6 +179,10 @@ LOG_MONITOR_MAX_OPEN_FILES = 200
# A constant used as object metadata to indicate the object is raw binary.
RAW_BUFFER_METADATA = b"RAW"
# A constant used as object metadata to indicate the object is pickled. This
# format is only ever used for Python inline task argument values.
PICKLE_BUFFER_METADATA = b"PICKLE"
# A constant used as object metadata to indicate the object is pickle5 format.
PICKLE5_BUFFER_METADATA = b"PICKLE5"
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
+9 -3
View File
@@ -134,6 +134,7 @@ class RemoteFunction(object):
args=None,
kwargs=None,
num_return_vals=None,
is_direct_call=None,
num_cpus=None,
num_gpus=None,
memory=None,
@@ -155,6 +156,8 @@ class RemoteFunction(object):
if num_return_vals is None:
num_return_vals = self._num_return_vals
if is_direct_call is None:
is_direct_call = False
resources = ray.utils.resources_from_resource_arguments(
self._num_cpus, self._num_gpus, self._memory,
@@ -162,8 +165,11 @@ class RemoteFunction(object):
memory, object_store_memory, resources)
def invocation(args, kwargs):
list_args = ray.signature.flatten_args(self._function_signature,
args, kwargs)
if not args and not kwargs and not self._function_signature:
list_args = []
else:
list_args = ray.signature.flatten_args(
self._function_signature, args, kwargs)
if worker.mode == ray.worker.LOCAL_MODE:
object_ids = worker.local_mode_manager.execute(
@@ -172,7 +178,7 @@ class RemoteFunction(object):
else:
object_ids = worker.core_worker.submit_task(
self._function_descriptor_list, list_args, num_return_vals,
resources)
is_direct_call, resources)
if len(object_ids) == 1:
return object_ids[0]
+2 -4
View File
@@ -157,8 +157,7 @@ class SerializationContext(object):
serialization_context)
def id_serializer(obj):
if isinstance(obj,
ray.ObjectID) and obj.is_direct_actor_type():
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
@@ -191,8 +190,7 @@ class SerializationContext(object):
custom_deserializer=actor_handle_deserializer)
def id_serializer(obj):
if isinstance(obj,
ray.ObjectID) and obj.is_direct_actor_type():
if isinstance(obj, ray.ObjectID) and obj.is_direct_call_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
-7
View File
@@ -7,7 +7,6 @@ import funcsigs
from funcsigs import Parameter
import logging
import ray
from ray.utils import is_cython
# Logger for this module. It should be configured at the entry point
@@ -136,12 +135,6 @@ def flatten_args(signature_parameters, args, kwargs):
[None, 1, None, 2, None, 3, "a", 4]
"""
for obj in args:
if isinstance(obj, ray.ObjectID) and obj.is_direct_actor_type():
raise NotImplementedError(
"Objects produced by direct actor calls cannot be "
"passed to other tasks as arguments.")
restored = _restore_parameters(signature_parameters)
reconstructed_signature = funcsigs.Signature(parameters=restored)
try:
+25 -4
View File
@@ -1190,6 +1190,31 @@ def test_get_dict(ray_start_regular):
assert result == expected
def test_direct_call_simple(ray_start_regular):
@ray.remote
def f(x):
return x + 1
f_direct = f.options(is_direct_call=True)
print("a")
assert ray.get(f_direct.remote(2)) == 3
print("b")
assert ray.get([f_direct.remote(i) for i in range(100)]) == list(
range(1, 101))
def test_direct_call_chain(ray_start_regular):
@ray.remote
def g(x):
return x + 1
g_direct = g.options(is_direct_call=True)
x = 0
for _ in range(100):
x = g_direct.remote(x)
assert ray.get(x) == 100
def test_direct_actor_enabled(ray_start_regular):
@ray.remote
class Actor(object):
@@ -1240,10 +1265,6 @@ def test_direct_actor_errors(ray_start_regular):
a = Actor._remote(is_direct_call=True)
# cannot pass returns to other methods directly
with pytest.raises(Exception):
ray.get(f.remote(a.f.remote(2)))
# cannot pass returns to other methods even in a list
with pytest.raises(Exception):
ray.get(f.remote([a.f.remote(2)]))
+7 -2
View File
@@ -4,6 +4,7 @@
#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <grpcpp/grpcpp.h>
#include <sstream>
#include "status.h"
namespace ray {
@@ -39,7 +40,7 @@ class MessageWrapper {
const Message &GetMessage() const { return *message_; }
/// Get reference of the protobuf message.
Message &GetMutableMessage() const { return *message_; }
Message &GetMutableMessage() { return *message_; }
/// Serialize the message to a string.
const std::string Serialize() const { return message_->SerializeAsString(); }
@@ -64,7 +65,11 @@ inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
if (grpc_status.ok()) {
return Status::OK();
} else {
return Status::IOError(grpc_status.error_message());
std::stringstream msg;
msg << grpc_status.error_code();
msg << ": ";
msg << grpc_status.error_message();
return Status::IOError(msg.str());
}
}
+3 -3
View File
@@ -18,7 +18,7 @@
namespace ray {
enum class TaskTransportType { RAYLET, DIRECT_ACTOR };
enum class TaskTransportType { RAYLET, DIRECT };
class TaskID;
class WorkerID;
@@ -290,8 +290,8 @@ class ObjectID : public BaseID<ObjectID> {
/// Return if this is a direct actor call object.
///
/// \return True if this is a direct actor object return.
bool IsDirectActorType() const {
return GetTransportType() == static_cast<uint8_t>(TaskTransportType::DIRECT_ACTOR);
bool IsDirectCallType() const {
return GetTransportType() == static_cast<uint8_t>(TaskTransportType::DIRECT);
}
/// Return this object id with a changed transport type.
+17
View File
@@ -9,6 +9,9 @@
namespace ray {
typedef std::function<void(const std::shared_ptr<void>, const std::string &, int)>
DispatchTaskCallback;
/// \class Task
///
/// A Task represents a Ray task and a specification of its execution (e.g.,
@@ -38,6 +41,13 @@ class Task {
ComputeDependencies();
}
/// Override dispatch behaviour.
void OnDispatchInstead(
std::function<void(const std::shared_ptr<void>, const std::string &, int)>
callback) {
on_dispatch_ = callback;
}
/// Get the mutable specification for the task. This specification may be
/// updated at runtime.
///
@@ -62,6 +72,9 @@ class Task {
/// \param task Task structure with updated dynamic information.
void CopyTaskExecutionSpec(const Task &task);
/// Returns the override dispatch task callback, or nullptr.
DispatchTaskCallback &OnDispatch() const { return on_dispatch_; }
std::string DebugString() const;
private:
@@ -78,6 +91,10 @@ class Task {
/// the TaskSpecification and execution dependencies from the
/// TaskExecutionSpecification.
std::vector<ObjectID> dependencies_;
/// For direct task calls, overrides the dispatch behaviour to send an RPC
/// back to the submitting worker.
mutable DispatchTaskCallback on_dispatch_ = nullptr;
};
} // namespace ray
+5 -2
View File
@@ -187,8 +187,11 @@ ObjectID TaskSpecification::ActorDummyObject() const {
}
bool TaskSpecification::IsDirectCall() const {
RAY_CHECK(IsActorCreationTask());
return message_->actor_creation_task_spec().is_direct_call();
if (IsActorCreationTask()) {
return message_->actor_creation_task_spec().is_direct_call();
} else {
return message_->is_direct_call();
}
}
int TaskSpecification::MaxActorConcurrency() const {
+2
View File
@@ -22,6 +22,8 @@ typedef std::pair<ResourceSet, FunctionDescriptor> SchedulingClassDescriptor;
typedef int SchedulingClass;
/// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details.
/// TODO(ekl) we should consider passing around std::unique_ptrs<TaskSpecification>
/// instead `const TaskSpecification`, since this class is actually mutable.
class TaskSpecification : public MessageWrapper<rpc::TaskSpec> {
public:
/// Construct an empty task specification. This should not be used directly.
+2 -1
View File
@@ -27,7 +27,7 @@ class TaskSpecBuilder {
const TaskID &task_id, const Language &language,
const std::vector<std::string> &function_descriptor, const JobID &job_id,
const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id,
uint64_t num_returns,
uint64_t num_returns, bool is_direct_call,
const std::unordered_map<std::string, double> &required_resources,
const std::unordered_map<std::string, double> &required_placement_resources) {
message_->set_type(TaskType::NORMAL_TASK);
@@ -41,6 +41,7 @@ class TaskSpecBuilder {
message_->set_parent_counter(parent_counter);
message_->set_caller_id(caller_id.Binary());
message_->set_num_returns(num_returns);
message_->set_is_direct_call(is_direct_call);
message_->mutable_required_resources()->insert(required_resources.begin(),
required_resources.end());
message_->mutable_required_placement_resources()->insert(
+5 -2
View File
@@ -84,11 +84,14 @@ class TaskArg {
/// Options for all tasks (actor and non-actor) except for actor creation.
struct TaskOptions {
TaskOptions() {}
TaskOptions(int num_returns, std::unordered_map<std::string, double> &resources)
: num_returns(num_returns), resources(resources) {}
TaskOptions(int num_returns, bool is_direct_call,
std::unordered_map<std::string, double> &resources)
: num_returns(num_returns), is_direct_call(is_direct_call), resources(resources) {}
/// Number of returns of this task.
int num_returns = 1;
/// Whether to use the direct task transport.
bool is_direct_call = false;
/// Resources required by this task.
std::unordered_map<std::string, double> resources;
};
+4 -3
View File
@@ -89,12 +89,13 @@ void WorkerContext::SetCurrentTask(const TaskSpecification &task_spec) {
if (task_spec.IsNormalTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
current_task_is_direct_call_ = task_spec.IsDirectCall();
} else if (task_spec.IsActorCreationTask()) {
RAY_CHECK(current_job_id_.IsNil());
SetCurrentJobId(task_spec.JobId());
RAY_CHECK(current_actor_id_.IsNil());
current_actor_id_ = task_spec.ActorCreationId();
current_actor_use_direct_call_ = task_spec.IsDirectCall();
current_task_is_direct_call_ = task_spec.IsDirectCall();
current_actor_max_concurrency_ = task_spec.MaxActorConcurrency();
} else if (task_spec.IsActorTask()) {
RAY_CHECK(current_job_id_ == task_spec.JobId());
@@ -117,8 +118,8 @@ std::shared_ptr<const TaskSpecification> WorkerContext::GetCurrentTask() const {
const ActorID &WorkerContext::GetCurrentActorID() const { return current_actor_id_; }
bool WorkerContext::CurrentActorUseDirectCall() const {
return current_actor_use_direct_call_;
bool WorkerContext::CurrentTaskIsDirectCall() const {
return current_task_is_direct_call_;
}
int WorkerContext::CurrentActorMaxConcurrency() const {
+3 -3
View File
@@ -34,7 +34,7 @@ class WorkerContext {
const ActorID &GetCurrentActorID() const;
bool CurrentActorUseDirectCall() const;
bool CurrentTaskIsDirectCall() const;
int CurrentActorMaxConcurrency() const;
@@ -47,8 +47,8 @@ class WorkerContext {
const WorkerID worker_id_;
JobID current_job_id_;
ActorID current_actor_id_;
bool current_actor_use_direct_call_ = false;
int current_actor_max_concurrency_;
bool current_task_is_direct_call_ = false;
int current_actor_max_concurrency_ = 1;
private:
static WorkerThreadContext &GetThreadContext(bool for_main_thread = false);
+83 -58
View File
@@ -19,11 +19,16 @@ void BuildCommonTaskSpec(
// Build common task spec.
builder.SetCommonTaskSpec(task_id, function.GetLanguage(),
function.GetFunctionDescriptor(), job_id, current_task_id,
task_index, caller_id, num_returns, required_resources,
required_placement_resources);
task_index, caller_id, num_returns,
transport_type == ray::TaskTransportType::DIRECT,
required_resources, required_placement_resources);
// Set task arguments.
for (const auto &arg : args) {
if (arg.IsPassedByReference()) {
if (transport_type == ray::TaskTransportType::RAYLET) {
RAY_CHECK(!arg.GetReference().IsDirectCallType())
<< "NotImplemented: passing direct call objects to other tasks";
}
builder.AddByRefArg(arg.GetReference());
} else {
builder.AddByValueArg(arg.GetValue());
@@ -44,7 +49,7 @@ void GroupObjectIdsByStoreProvider(const std::vector<ObjectID> &object_ids,
absl::flat_hash_set<ObjectID> *plasma_object_ids,
absl::flat_hash_set<ObjectID> *memory_object_ids) {
for (const auto &object_id : object_ids) {
if (object_id.IsDirectActorType()) {
if (object_id.IsDirectCallType()) {
memory_object_ids->insert(object_id);
} else {
plasma_object_ids->insert(object_id);
@@ -70,11 +75,12 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
check_signals_(check_signals),
worker_context_(worker_type, job_id),
io_work_(io_service_),
client_call_manager_(new rpc::ClientCallManager(io_service_)),
heartbeat_timer_(io_service_),
worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */),
gcs_client_(gcs_options),
client_call_manager_(io_service_),
memory_store_(std::make_shared<CoreWorkerMemoryStore>()),
memory_store_provider_(memory_store_),
task_execution_service_work_(task_execution_service_),
task_execution_callback_(task_execution_callback),
grpc_service_(io_service_, *this) {
@@ -95,24 +101,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
profiler_ = std::make_shared<worker::Profiler>(worker_context_, node_ip_address,
io_service_, gcs_client_);
// Initialize task execution.
// Initialize task receivers.
if (worker_type_ == WorkerType::WORKER) {
RAY_CHECK(task_execution_callback_ != nullptr);
// Initialize task receivers.
auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3);
raylet_task_receiver_ = std::unique_ptr<CoreWorkerRayletTaskReceiver>(
new CoreWorkerRayletTaskReceiver(raylet_client_, execute_task, exit_handler));
direct_actor_task_receiver_ = std::unique_ptr<CoreWorkerDirectActorTaskReceiver>(
new CoreWorkerDirectActorTaskReceiver(worker_context_, task_execution_service_,
worker_server_, execute_task,
exit_handler));
worker_server_.RegisterService(grpc_service_);
direct_task_receiver_ =
std::unique_ptr<CoreWorkerDirectTaskReceiver>(new CoreWorkerDirectTaskReceiver(
worker_context_, task_execution_service_, execute_task, exit_handler));
}
// Start RPC server after all the task receivers are properly initialized.
worker_server_.Run();
core_worker_server_.RegisterService(grpc_service_);
core_worker_server_.Run();
// Initialize raylet client.
// TODO(zhijunfu): currently RayletClient would crash in its constructor if it cannot
@@ -120,15 +123,15 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
// so that the worker (java/python .etc) can retrieve and handle the error
// instead of crashing.
auto grpc_client = rpc::NodeManagerWorkerClient::make(
node_ip_address, node_manager_port, client_call_manager_);
node_ip_address, node_manager_port, *client_call_manager_);
raylet_client_ = std::unique_ptr<RayletClient>(new RayletClient(
std::move(grpc_client), raylet_socket,
WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()),
(worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(),
language_, worker_server_.GetPort()));
language_, core_worker_server_.GetPort()));
// Unfortunately the raylet client has to be constructed after the receivers.
if (direct_actor_task_receiver_ != nullptr) {
direct_actor_task_receiver_->Init(*raylet_client_);
if (direct_task_receiver_ != nullptr) {
direct_task_receiver_->Init(*raylet_client_);
}
// Set timer to periodically send heartbeats containing active object IDs to the raylet.
@@ -149,7 +152,6 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
plasma_store_provider_.reset(
new CoreWorkerPlasmaStoreProvider(store_socket, raylet_client_, check_signals_));
memory_store_provider_.reset(new CoreWorkerMemoryStoreProvider(memory_store_));
// Create an entry for the driver task in the task table. This task is
// added immediately with status RUNNING. This allows us to push errors
@@ -162,10 +164,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
std::vector<std::string> empty_descriptor;
std::unordered_map<std::string, double> empty_resources;
const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID());
builder.SetCommonTaskSpec(task_id, language_, empty_descriptor,
worker_context_.GetCurrentJobID(),
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()),
0, GetCallerId(), 0, empty_resources, empty_resources);
builder.SetCommonTaskSpec(
task_id, language_, empty_descriptor, worker_context_.GetCurrentJobID(),
TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), 0, GetCallerId(), 0,
false, empty_resources, empty_resources);
std::shared_ptr<gcs::TaskTableData> data = std::make_shared<gcs::TaskTableData>();
data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage());
@@ -173,11 +175,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language,
SetCurrentTaskId(task_id);
}
// TODO(edoakes): why don't we just share the memory store provider?
direct_actor_submitter_ = std::unique_ptr<CoreWorkerDirectActorTaskSubmitter>(
new CoreWorkerDirectActorTaskSubmitter(
io_service_, std::unique_ptr<CoreWorkerMemoryStoreProvider>(
new CoreWorkerMemoryStoreProvider(memory_store_))));
new CoreWorkerDirectActorTaskSubmitter(*client_call_manager_,
memory_store_provider_));
direct_task_submitter_ =
std::unique_ptr<CoreWorkerDirectTaskSubmitter>(new CoreWorkerDirectTaskSubmitter(
*raylet_client_,
[this](WorkerAddress addr) {
return std::shared_ptr<rpc::CoreWorkerClient>(new rpc::CoreWorkerClient(
addr.first, addr.second, *client_call_manager_));
},
memory_store_provider_));
}
CoreWorker::~CoreWorker() {
@@ -308,9 +317,9 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
local_timeout_ms = std::max(static_cast<int64_t>(0),
timeout_ms - (current_time_ms() - start_time));
}
RAY_RETURN_NOT_OK(memory_store_provider_->Get(memory_object_ids, local_timeout_ms,
worker_context_.GetCurrentTaskID(),
&result_map, &got_exception));
RAY_RETURN_NOT_OK(memory_store_provider_.Get(memory_object_ids, local_timeout_ms,
worker_context_.GetCurrentTaskID(),
&result_map, &got_exception));
}
// If any of the objects have been promoted to plasma, then we retry their
@@ -334,9 +343,9 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
for (const auto &id : promoted_plasma_ids) {
auto it = result_map.find(id);
if (it == result_map.end()) {
result_map.erase(id.WithTransportType(TaskTransportType::DIRECT_ACTOR));
result_map.erase(id.WithTransportType(TaskTransportType::DIRECT));
} else {
result_map[id.WithTransportType(TaskTransportType::DIRECT_ACTOR)] = it->second;
result_map[id.WithTransportType(TaskTransportType::DIRECT)] = it->second;
}
result_map.erase(id);
}
@@ -360,10 +369,10 @@ Status CoreWorker::Get(const std::vector<ObjectID> &ids, const int64_t timeout_m
Status CoreWorker::Contains(const ObjectID &object_id, bool *has_object) {
bool found = false;
if (object_id.IsDirectActorType()) {
if (object_id.IsDirectCallType()) {
// Note that the memory store returns false if the object value is
// ErrorType::OBJECT_IN_PLASMA.
RAY_RETURN_NOT_OK(memory_store_provider_->Contains(object_id, &found));
RAY_RETURN_NOT_OK(memory_store_provider_.Contains(object_id, &found));
}
if (!found) {
// We check plasma as a fallback in all cases, since a direct call object
@@ -413,7 +422,7 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
if (memory_object_ids.size() > 0) {
// TODO(ekl) for memory objects that are ErrorType::OBJECT_IN_PLASMA, we should
// consider waiting on them in plasma as well to ensure they are local.
RAY_RETURN_NOT_OK(memory_store_provider_->Wait(
RAY_RETURN_NOT_OK(memory_store_provider_.Wait(
memory_object_ids, std::max(0, static_cast<int>(ready.size()) - num_objects),
/*timeout_ms=*/0, worker_context_.GetCurrentTaskID(), &ready));
}
@@ -431,8 +440,8 @@ Status CoreWorker::Wait(const std::vector<ObjectID> &ids, int num_objects,
}
if (memory_object_ids.size() > 0) {
RAY_RETURN_NOT_OK(
memory_store_provider_->Wait(memory_object_ids, num_objects, timeout_ms,
worker_context_.GetCurrentTaskID(), &ready));
memory_store_provider_.Wait(memory_object_ids, num_objects, timeout_ms,
worker_context_.GetCurrentTaskID(), &ready));
}
}
@@ -453,7 +462,7 @@ Status CoreWorker::Delete(const std::vector<ObjectID> &object_ids, bool local_on
RAY_RETURN_NOT_OK(plasma_store_provider_->Delete(plasma_object_ids, local_only,
delete_creating_tasks));
RAY_RETURN_NOT_OK(memory_store_provider_->Delete(memory_object_ids));
RAY_RETURN_NOT_OK(memory_store_provider_.Delete(memory_object_ids));
return Status::OK();
}
@@ -494,8 +503,7 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) {
if (task_deps->size() > 0) {
for (size_t i = 0; i < num_returns; i++) {
reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET),
task_deps);
reference_counter_.SetDependencies(task_spec.ReturnIdForPlasma(i), task_deps);
}
}
@@ -511,11 +519,19 @@ Status CoreWorker::SubmitTask(const RayFunction &function,
const auto task_id =
TaskID::ForNormalTask(worker_context_.GetCurrentJobID(),
worker_context_.GetCurrentTaskID(), next_task_index);
BuildCommonTaskSpec(builder, worker_context_.GetCurrentJobID(), task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(),
function, args, task_options.num_returns, task_options.resources,
{}, TaskTransportType::RAYLET, return_ids);
return SubmitTaskToRaylet(builder.Build());
// TODO(ekl) offload task building onto a thread pool for performance
BuildCommonTaskSpec(
builder, worker_context_.GetCurrentJobID(), task_id,
worker_context_.GetCurrentTaskID(), next_task_index, GetCallerId(), function, args,
task_options.num_returns, task_options.resources, {},
task_options.is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET,
return_ids);
if (task_options.is_direct_call) {
return direct_task_submitter_->SubmitTask(builder.Build());
} else {
return raylet_client_->SubmitTask(builder.Build());
}
}
Status CoreWorker::CreateActor(const RayFunction &function,
@@ -562,7 +578,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f
const bool is_direct_call = actor_handle->IsDirectCallActor();
const TaskTransportType transport_type =
is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET;
is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET;
// Build common task spec.
TaskSpecBuilder builder;
@@ -675,7 +691,7 @@ Status CoreWorker::AllocateReturnObjects(
bool object_already_exists = false;
std::shared_ptr<Buffer> data_buffer;
if (data_sizes[i] > 0) {
if (worker_context_.CurrentActorUseDirectCall() &&
if (worker_context_.CurrentTaskIsDirectCall() &&
static_cast<int64_t>(data_sizes[i]) <
RayConfig::instance().max_direct_call_object_size()) {
data_buffer = std::make_shared<LocalMemoryBuffer>(data_sizes[i]);
@@ -710,8 +726,8 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
std::vector<ObjectID> arg_reference_ids;
RAY_CHECK_OK(BuildArgsForExecutor(task_spec, &args, &arg_reference_ids));
const auto transport_type = worker_context_.CurrentActorUseDirectCall()
? TaskTransportType::DIRECT_ACTOR
const auto transport_type = worker_context_.CurrentTaskIsDirectCall()
? TaskTransportType::DIRECT
: TaskTransportType::RAYLET;
std::vector<ObjectID> return_ids;
for (size_t i = 0; i < task_spec.NumReturns(); i++) {
@@ -745,7 +761,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec,
RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to seal object "
<< return_ids[i] << " in store: " << status.message();
}
} else if (!worker_context_.CurrentActorUseDirectCall()) {
} else if (!worker_context_.CurrentTaskIsDirectCall()) {
if (!Put(*return_objects->at(i), return_ids[i]).ok()) {
RAY_LOG(FATAL) << "Task " << task_spec.TaskId() << " failed to put object "
<< return_ids[i] << " in store: " << status.message();
@@ -816,7 +832,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task,
void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
if (worker_context_.CurrentActorUseDirectCall()) {
if (worker_context_.CurrentTaskIsDirectCall()) {
send_reply_callback(Status::Invalid("This actor only accepts direct calls."), nullptr,
nullptr);
return;
@@ -827,12 +843,11 @@ void CoreWorker::HandleAssignTask(const rpc::AssignTaskRequest &request,
}
}
void CoreWorker::HandleDirectActorAssignTask(
const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
void CoreWorker::HandlePushTask(const rpc::PushTaskRequest &request,
rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
task_execution_service_.post([=] {
direct_actor_task_receiver_->HandleDirectActorAssignTask(request, reply,
send_reply_callback);
direct_task_receiver_->HandlePushTask(request, reply, send_reply_callback);
});
}
@@ -841,9 +856,19 @@ void CoreWorker::HandleDirectActorCallArgWaitComplete(
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) {
task_execution_service_.post([=] {
direct_actor_task_receiver_->HandleDirectActorCallArgWaitComplete(
request, reply, send_reply_callback);
direct_task_receiver_->HandleDirectActorCallArgWaitComplete(request, reply,
send_reply_callback);
});
}
void CoreWorker::HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request,
rpc::WorkerLeaseGrantedReply *reply,
rpc::SendReplyCallback send_reply_callback) {
// Run this directly since the main thread may be tied up processing a task and
// we need to still continue processing these scheduling operations in the backend.
direct_task_submitter_->HandleWorkerLeaseGranted(
std::make_pair(request.address(), request.port()));
send_reply_callback(Status::OK(), nullptr, nullptr);
}
} // namespace ray
+33 -25
View File
@@ -14,22 +14,24 @@
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/store_provider/plasma_store_provider.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/core_worker/transport/direct_task_transport.h"
#include "ray/core_worker/transport/raylet_transport.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/node_manager/node_manager_client.h"
#include "ray/rpc/worker/worker_client.h"
#include "ray/rpc/worker/worker_server.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "ray/rpc/worker/core_worker_server.h"
/// The set of gRPC handlers and their associated level of concurrency. If you want to
/// add a new call to the worker gRPC server, do the following:
/// 1) Add the rpc to the WorkerService in core_worker.proto, e.g., "ExampleCall"
/// 1) Add the rpc to the CoreWorkerService in core_worker.proto, e.g., "ExampleCall"
/// 2) Add a new handler to the macro below: "RAY_CORE_WORKER_RPC_HANDLER(ExampleCall, 1)"
/// 3) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall"
#define RAY_CORE_WORKER_RPC_HANDLERS \
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
RAY_CORE_WORKER_RPC_HANDLER(DirectActorAssignTask, 9999) \
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100)
#define RAY_CORE_WORKER_RPC_HANDLERS \
RAY_CORE_WORKER_RPC_HANDLER(AssignTask, 5) \
RAY_CORE_WORKER_RPC_HANDLER(PushTask, 9999) \
RAY_CORE_WORKER_RPC_HANDLER(DirectActorCallArgWaitComplete, 100) \
RAY_CORE_WORKER_RPC_HANDLER(WorkerLeaseGranted, 5)
namespace ray {
@@ -319,29 +321,32 @@ class CoreWorker {
const std::vector<std::shared_ptr<Buffer>> &metadatas,
std::vector<std::shared_ptr<RayObject>> *return_objects);
/* Handlers for the worker's gRPC server. These are executed on the io_service_ and post
* work to the appropriate event loop.
/**
* The following methods are handlers for the core worker's gRPC server, which follow
* a macro-generated call convention. These are executed on the io_service_ and
* post work to the appropriate event loop.
*/
/// Handle an "AssignTask" event corresponding to scheduling a normal or an actor task
/// on this worker from the raylet.
/// Implements gRPC server handler.
void HandleAssignTask(const rpc::AssignTaskRequest &request,
rpc::AssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a "DirectActorAssignTask" event corresponding to scheduling an actor task
/// on this worker from another worker.
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Implements gRPC server handler.
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a "DirectActorAssignTask" event corresponding to the raylet notifiying this
/// worker that an argument is ready.
/// Implements gRPC server handler.
void HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Implements gRPC server handler.
void HandleWorkerLeaseGranted(const rpc::WorkerLeaseGrantedRequest &request,
rpc::WorkerLeaseGrantedReply *reply,
rpc::SendReplyCallback send_reply_callback);
private:
/// Run the io_service_ event loop. This should be called in a background thread.
void RunIOService();
@@ -446,19 +451,19 @@ class CoreWorker {
/// Keeps the io_service_ alive.
boost::asio::io_service::work io_work_;
/// Shared client call manager.
std::unique_ptr<rpc::ClientCallManager> client_call_manager_;
/// Timer used to periodically send heartbeat containing active object IDs to the
/// raylet.
boost::asio::steady_timer heartbeat_timer_;
/// RPC server used to receive tasks to execute.
rpc::GrpcServer worker_server_;
rpc::GrpcServer core_worker_server_;
// Client to the GCS shared by core worker interfaces.
gcs::RedisGcsClient gcs_client_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
rpc::ClientCallManager client_call_manager_;
// Client to the raylet shared by core worker interfaces.
std::unique_ptr<RayletClient> raylet_client_;
@@ -479,7 +484,7 @@ class CoreWorker {
std::unique_ptr<CoreWorkerPlasmaStoreProvider> plasma_store_provider_;
/// In-memory store interface.
std::unique_ptr<CoreWorkerMemoryStoreProvider> memory_store_provider_;
CoreWorkerMemoryStoreProvider memory_store_provider_;
///
/// Fields related to task submission.
@@ -488,6 +493,9 @@ class CoreWorker {
// Interface to submit tasks directly to other actors.
std::unique_ptr<CoreWorkerDirectActorTaskSubmitter> direct_actor_submitter_;
// Interface to submit non-actor tasks directly to leased workers.
std::unique_ptr<CoreWorkerDirectTaskSubmitter> direct_task_submitter_;
/// Map from actor ID to a handle to that actor.
absl::flat_hash_map<ActorID, std::unique_ptr<ActorHandle>> actor_handles_;
@@ -519,10 +527,10 @@ class CoreWorker {
std::unique_ptr<CoreWorkerRayletTaskReceiver> raylet_task_receiver_;
/// Common rpc service for all worker modules.
rpc::WorkerGrpcService grpc_service_;
rpc::CoreWorkerGrpcService grpc_service_;
// Interface that receives tasks from direct actor calls.
std::unique_ptr<CoreWorkerDirectActorTaskReceiver> direct_actor_task_receiver_;
std::unique_ptr<CoreWorkerDirectTaskReceiver> direct_task_receiver_;
friend class CoreWorkerTest;
};
@@ -106,32 +106,66 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
CoreWorkerMemoryStore::CoreWorkerMemoryStore() {}
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
std::unique_lock<std::mutex> lock(lock_);
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
return Status::ObjectExists("object already exists in the memory store");
void CoreWorkerMemoryStore::GetAsync(
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
std::shared_ptr<RayObject> ptr;
{
absl::MutexLock lock(&mu_);
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
ptr = iter->second;
} else {
object_async_get_requests_[object_id].push_back(callback);
}
}
// It's important for performance to run the callback outside the lock.
if (ptr != nullptr) {
callback(ptr);
}
}
Status CoreWorkerMemoryStore::Put(const ObjectID &object_id, const RayObject &object) {
std::vector<std::function<void(std::shared_ptr<RayObject>)>> async_callbacks;
auto object_entry =
std::make_shared<RayObject>(object.GetData(), object.GetMetadata(), true);
bool should_add_entry = true;
auto object_request_iter = object_get_requests_.find(object_id);
if (object_request_iter != object_get_requests_.end()) {
auto &get_requests = object_request_iter->second;
for (auto &get_request : get_requests) {
get_request->Set(object_id, object_entry);
if (get_request->ShouldRemoveObjects()) {
should_add_entry = false;
{
absl::MutexLock lock(&mu_);
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
return Status::ObjectExists("object already exists in the memory store");
}
auto async_callback_it = object_async_get_requests_.find(object_id);
if (async_callback_it != object_async_get_requests_.end()) {
auto &callbacks = async_callback_it->second;
async_callbacks = std::move(callbacks);
object_async_get_requests_.erase(async_callback_it);
}
bool should_add_entry = true;
auto object_request_iter = object_get_requests_.find(object_id);
if (object_request_iter != object_get_requests_.end()) {
auto &get_requests = object_request_iter->second;
for (auto &get_request : get_requests) {
get_request->Set(object_id, object_entry);
if (get_request->ShouldRemoveObjects()) {
should_add_entry = false;
}
}
}
if (should_add_entry) {
// If there is no existing get request, then add the `RayObject` to map.
objects_.emplace(object_id, object_entry);
}
}
if (should_add_entry) {
// If there is no existing get request, then add the `RayObject` to map.
objects_.emplace(object_id, object_entry);
// It's important for performance to run the callbacks outside the lock.
for (const auto &cb : async_callbacks) {
cb(object_entry);
}
return Status::OK();
}
@@ -147,7 +181,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
absl::flat_hash_set<ObjectID> remaining_ids;
absl::flat_hash_set<ObjectID> ids_to_remove;
std::unique_lock<std::mutex> lock(lock_);
absl::MutexLock lock(&mu_);
// Check for existing objects and see if this get request can be fullfilled.
for (size_t i = 0; i < object_ids.size(); i++) {
const auto &object_id = object_ids[i];
@@ -192,7 +226,7 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
get_request->Wait(timeout_ms);
{
std::unique_lock<std::mutex> lock(lock_);
absl::MutexLock lock(&mu_);
// Populate results.
for (size_t i = 0; i < object_ids.size(); i++) {
const auto &object_id = object_ids[i];
@@ -223,14 +257,14 @@ Status CoreWorkerMemoryStore::Get(const std::vector<ObjectID> &object_ids,
}
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
std::unique_lock<std::mutex> lock(lock_);
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
objects_.erase(object_id);
}
}
bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id) {
std::unique_lock<std::mutex> lock(lock_);
absl::MutexLock lock(&mu_);
auto it = objects_.find(object_id);
// If obj is in plasma, we defer to the plasma store for the Contains() call.
return it != objects_.end() && !it->second->IsInPlasmaError();
@@ -3,6 +3,7 @@
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/status.h"
#include "ray/core_worker/common.h"
@@ -39,6 +40,15 @@ class CoreWorkerMemoryStore {
Status Get(const std::vector<ObjectID> &object_ids, int num_objects, int64_t timeout_ms,
bool remove_after_get, std::vector<std::shared_ptr<RayObject>> *results);
/// Asynchronously get an object from the object store. The object will not be removed
/// from storage after GetAsync (TODO(ekl): integrate this with object GC).
///
/// \param[in] object_id The object id to get.
/// \param[in] callback The callback to run with the reference to the retrieved
/// object value once available.
void GetAsync(const ObjectID &object_id,
std::function<void(std::shared_ptr<RayObject>)> callback);
/// Delete a list of objects from the object store.
///
/// \param[in] object_ids IDs of the objects to delete.
@@ -53,14 +63,19 @@ class CoreWorkerMemoryStore {
private:
/// Map from object ID to `RayObject`.
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_;
absl::flat_hash_map<ObjectID, std::shared_ptr<RayObject>> objects_ GUARDED_BY(mu_);
/// Map from object ID to its get requests.
absl::flat_hash_map<ObjectID, std::vector<std::shared_ptr<GetRequest>>>
object_get_requests_;
object_get_requests_ GUARDED_BY(mu_);
/// Map from object ID to its async get requests.
absl::flat_hash_map<ObjectID,
std::vector<std::function<void(std::shared_ptr<RayObject>)>>>
object_async_get_requests_ GUARDED_BY(mu_);
/// Protect the two maps above.
std::mutex lock_;
absl::Mutex mu_;
};
} // namespace ray
@@ -19,6 +19,11 @@ class CoreWorkerMemoryStoreProvider {
public:
CoreWorkerMemoryStoreProvider(std::shared_ptr<CoreWorkerMemoryStore> store);
void GetAsync(const ObjectID &object_id,
std::function<void(std::shared_ptr<RayObject>)> callback) {
store_->GetAsync(object_id, callback);
}
Status Put(const RayObject &object, const ObjectID &object_id);
Status Get(const absl::flat_hash_set<ObjectID> &object_ids, int64_t timeout_ms,
+16 -27
View File
@@ -44,16 +44,6 @@ static void flushall_redis(void) {
redisFree(context);
}
std::shared_ptr<Buffer> GenerateRandomBuffer() {
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
std::mt19937 gen(seed);
std::uniform_int_distribution<> dis(1, 10);
std::uniform_int_distribution<> value_dis(1, 255);
std::vector<uint8_t> arg1(dis(gen), value_dis(gen));
return std::make_shared<LocalMemoryBuffer>(arg1.data(), arg1.size(), true);
}
ActorID CreateActorHelper(CoreWorker &worker,
std::unordered_map<std::string, double> &resources,
bool is_direct_call, uint64_t max_reconstructions) {
@@ -279,16 +269,15 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
RayFunction func(ray::Language::PYTHON, {});
RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids));
ASSERT_EQ(return_ids.size(), 1);
ASSERT_TRUE(return_ids[0].IsReturnObject());
ASSERT_EQ(
static_cast<TaskTransportType>(return_ids[0].GetTransportType()),
is_direct_call ? TaskTransportType::DIRECT_ACTOR : TaskTransportType::RAYLET);
ASSERT_EQ(static_cast<TaskTransportType>(return_ids[0].GetTransportType()),
is_direct_call ? TaskTransportType::DIRECT : TaskTransportType::RAYLET);
std::vector<std::shared_ptr<ray::RayObject>> results;
RAY_CHECK_OK(driver.Get(return_ids, -1, &results));
@@ -320,7 +309,7 @@ void CoreWorkerTest::TestActorTask(std::unordered_map<std::string, double> &reso
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer2, nullptr)));
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
RayFunction func(ray::Language::PYTHON, {});
auto status = driver.SubmitActorTask(actor_id, func, args, options, &return_ids);
@@ -380,7 +369,7 @@ void CoreWorkerTest::TestActorReconstruction(
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
RayFunction func(ray::Language::PYTHON, {});
@@ -425,7 +414,7 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map<std::string, double> &r
args.emplace_back(
TaskArg::PassByValue(std::make_shared<RayObject>(buffer1, nullptr)));
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
RayFunction func(ray::Language::PYTHON, {});
@@ -486,7 +475,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
ASSERT_EQ(*data, *buffer);
}
// Performance batchmark for `DirectActorAssignTaskRequest` creation.
// Performance batchmark for `PushTaskRequest` creation.
TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
// Create a dummy actor handle, and then create a number of `TaskSpec`
// to benchmark performance.
@@ -505,20 +494,21 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
function.GetFunctionDescriptor());
// Manually create `num_tasks` task specs, and for each of them create a
// `DirectActorAssignTaskRequest`, this is to batch performance of TaskSpec
// `PushTaskRequest`, this is to batch performance of TaskSpec
// creation/copy/destruction.
int64_t start_ms = current_time_ms();
const auto num_tasks = 10000 * 10;
RAY_LOG(INFO) << "start creating " << num_tasks << " DirectActorAssignTaskRequests";
RAY_LOG(INFO) << "start creating " << num_tasks << " PushTaskRequests";
for (int i = 0; i < num_tasks; i++) {
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
auto num_returns = options.num_returns;
TaskSpecBuilder builder;
builder.SetCommonTaskSpec(RandomTaskId(), function.GetLanguage(),
function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0,
RandomTaskId(), num_returns, resources, resources);
RandomTaskId(), num_returns, /*is_direct*/ false, resources,
resources);
// Set task arguments.
for (const auto &arg : args) {
if (arg.IsPassedByReference()) {
@@ -531,14 +521,13 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) {
actor_handle.SetActorTaskSpec(builder, TaskTransportType::RAYLET,
ObjectID::FromRandom());
const auto &task_spec = builder.Build();
auto task_spec = builder.Build();
ASSERT_TRUE(task_spec.IsActorTask());
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
new rpc::DirectActorAssignTaskRequest);
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
}
RAY_LOG(INFO) << "Finish creating " << num_tasks << " DirectActorAssignTaskRequests"
RAY_LOG(INFO) << "Finish creating " << num_tasks << " PushTaskRequests"
<< ", which takes " << current_time_ms() - start_ms << " ms";
}
@@ -566,7 +555,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) {
sizeof(array));
args.emplace_back(TaskArg::PassByValue(std::make_shared<RayObject>(buffer, nullptr)));
TaskOptions options{1, resources};
TaskOptions options{1, false, resources};
std::vector<ObjectID> return_ids;
RayFunction func(ray::Language::PYTHON, {});
@@ -0,0 +1,272 @@
#include "gtest/gtest.h"
#include "ray/common/task/task_spec.h"
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/transport/direct_task_transport.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include "src/ray/util/test_util.h"
namespace ray {
class MockWorkerClient : public rpc::CoreWorkerClientInterface {
public:
ray::Status PushNormalTask(
std::unique_ptr<rpc::PushTaskRequest> request,
const rpc::ClientCallback<rpc::PushTaskReply> &callback) override {
callbacks.push_back(callback);
return Status::OK();
}
std::vector<rpc::ClientCallback<rpc::PushTaskReply>> callbacks;
};
class MockRayletClient : public WorkerLeaseInterface {
public:
ray::Status ReturnWorker(int worker_port) {
num_workers_returned += 1;
return Status::OK();
}
ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) {
num_workers_requested += 1;
return Status::OK();
}
int num_workers_requested = 0;
int num_workers_returned = 0;
};
TEST(LocalDependencyResolverTest, TestNoDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
LocalDependencyResolver resolver(store);
TaskSpecification task;
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_TRUE(ok);
}
TEST(LocalDependencyResolverTest, TestIgnorePlasmaDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::RAYLET);
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
// We ignore and don't block on plasma dependencies.
ASSERT_TRUE(ok);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(LocalDependencyResolverTest, TestInlineLocalDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
// Ensure the data is already present in the local store.
ASSERT_TRUE(store.Put(*data, obj1).ok());
ASSERT_TRUE(store.Put(*data, obj2).ok());
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
// Tests that the task proto was rewritten to have inline argument values.
ASSERT_TRUE(ok);
ASSERT_FALSE(task.ArgByRef(0));
ASSERT_FALSE(task.ArgByRef(1));
ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) {
auto ptr = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
CoreWorkerMemoryStoreProvider store(ptr);
LocalDependencyResolver resolver(store);
ObjectID obj1 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
ObjectID obj2 = ObjectID::FromRandom().WithTransportType(TaskTransportType::DIRECT);
auto data = GenerateRandomObject();
TaskSpecification task;
task.GetMutableMessage().add_args()->add_object_ids(obj1.Binary());
task.GetMutableMessage().add_args()->add_object_ids(obj2.Binary());
bool ok = false;
resolver.ResolveDependencies(task, [&ok]() { ok = true; });
ASSERT_EQ(resolver.NumPendingTasks(), 1);
ASSERT_TRUE(!ok);
ASSERT_TRUE(store.Put(*data, obj1).ok());
ASSERT_TRUE(store.Put(*data, obj2).ok());
// Tests that the task proto was rewritten to have inline argument values after
// resolution completes.
ASSERT_TRUE(ok);
ASSERT_FALSE(task.ArgByRef(0));
ASSERT_FALSE(task.ArgByRef(1));
ASSERT_NE(task.ArgData(0), nullptr);
ASSERT_NE(task.ArgData(1), nullptr);
ASSERT_EQ(resolver.NumPendingTasks(), 0);
}
TEST(DirectTaskTranportTest, TestSubmitOneTask) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok());
ASSERT_EQ(raylet_client.num_workers_requested, 1);
ASSERT_EQ(raylet_client.num_workers_returned, 0);
ASSERT_EQ(worker_client->callbacks.size(), 0);
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234));
ASSERT_EQ(worker_client->callbacks.size(), 1);
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(raylet_client.num_workers_returned, 1);
}
TEST(DirectTaskTranportTest, TestHandleTaskFailure) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task;
task.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task).ok());
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1234));
// Simulate a system failure, i.e., worker died unexpectedly.
worker_client->callbacks[0](Status::IOError("oops"), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client.num_workers_returned, 1);
}
TEST(DirectTaskTranportTest, TestConcurrentWorkerLeases) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;
TaskSpecification task2;
TaskSpecification task3;
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
ASSERT_TRUE(submitter.SubmitTask(task3).ok());
ASSERT_EQ(raylet_client.num_workers_requested, 1);
// Task 1 is pushed; worker 2 is requested.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client.num_workers_requested, 2);
// Task 2 is pushed; worker 3 is requested.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
ASSERT_EQ(worker_client->callbacks.size(), 2);
ASSERT_EQ(raylet_client.num_workers_requested, 3);
// Task 3 is pushed; no more workers requested.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1002));
ASSERT_EQ(worker_client->callbacks.size(), 3);
ASSERT_EQ(raylet_client.num_workers_requested, 3);
// All workers returned.
for (const auto &cb : worker_client->callbacks) {
cb(Status::OK(), rpc::PushTaskReply());
}
ASSERT_EQ(raylet_client.num_workers_returned, 3);
}
TEST(DirectTaskTranportTest, TestReuseWorkerLease) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;
TaskSpecification task2;
TaskSpecification task3;
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
task3.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
ASSERT_TRUE(submitter.SubmitTask(task3).ok());
ASSERT_EQ(raylet_client.num_workers_requested, 1);
// Task 1 is pushed.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client.num_workers_requested, 2);
// Task 1 finishes, Task 2 is scheduled on the same worker.
worker_client->callbacks[0](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 2);
ASSERT_EQ(raylet_client.num_workers_returned, 0);
// Task 2 finishes, Task 3 is scheduled on the same worker.
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 3);
ASSERT_EQ(raylet_client.num_workers_returned, 0);
// Task 3 finishes, the worker is returned.
worker_client->callbacks[2](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(raylet_client.num_workers_returned, 1);
// The second lease request is returned immediately.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
ASSERT_EQ(raylet_client.num_workers_returned, 2);
}
TEST(DirectTaskTranportTest, TestWorkerNotReusedOnError) {
MockRayletClient raylet_client;
auto worker_client = std::shared_ptr<MockWorkerClient>(new MockWorkerClient());
auto store = std::shared_ptr<CoreWorkerMemoryStore>(new CoreWorkerMemoryStore());
auto factory = [&](WorkerAddress addr) { return worker_client; };
CoreWorkerDirectTaskSubmitter submitter(raylet_client, factory, store);
TaskSpecification task1;
TaskSpecification task2;
task1.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
task2.GetMutableMessage().set_task_id(TaskID::Nil().Binary());
ASSERT_TRUE(submitter.SubmitTask(task1).ok());
ASSERT_TRUE(submitter.SubmitTask(task2).ok());
ASSERT_EQ(raylet_client.num_workers_requested, 1);
// Task 1 is pushed.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1000));
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client.num_workers_requested, 2);
// Task 1 finishes with failure; the worker is returned.
worker_client->callbacks[0](Status::IOError("worker dead"), rpc::PushTaskReply());
ASSERT_EQ(worker_client->callbacks.size(), 1);
ASSERT_EQ(raylet_client.num_workers_returned, 1);
// Task 2 runs successfully on the second worker.
submitter.HandleWorkerLeaseGranted(std::make_pair("localhost", 1001));
ASSERT_EQ(worker_client->callbacks.size(), 2);
worker_client->callbacks[1](Status::OK(), rpc::PushTaskReply());
ASSERT_EQ(raylet_client.num_workers_returned, 2);
}
} // namespace ray
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
@@ -6,14 +6,11 @@ using ray::rpc::ActorTableData;
namespace ray {
CoreWorkerDirectActorTaskSubmitter::CoreWorkerDirectActorTaskSubmitter(
boost::asio::io_service &io_service,
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider)
: io_service_(io_service),
client_call_manager_(io_service),
in_memory_store_(std::move(store_provider)) {}
rpc::ClientCallManager &client_call_manager,
CoreWorkerMemoryStoreProvider store_provider)
: client_call_manager_(client_call_manager), in_memory_store_(store_provider) {}
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
const TaskSpecification &task_spec) {
Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
RAY_LOG(DEBUG) << "Submitting task " << task_spec.TaskId();
RAY_CHECK(task_spec.IsActorTask());
@@ -22,8 +19,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
const auto task_id = task_spec.TaskId();
const auto num_returns = task_spec.NumReturns();
auto request = std::unique_ptr<rpc::DirectActorAssignTaskRequest>(
new rpc::DirectActorAssignTaskRequest);
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
std::unique_lock<std::mutex> guard(mutex_);
@@ -49,7 +45,7 @@ Status CoreWorkerDirectActorTaskSubmitter::SubmitTask(
// Submit request.
auto &client = rpc_clients_[actor_id];
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
} else {
// Actor is dead, treat the task as failure.
RAY_CHECK(iter->second.state_ == ActorTableData::DEAD);
@@ -106,8 +102,8 @@ void CoreWorkerDirectActorTaskSubmitter::HandleActorUpdate(
void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
const ActorID &actor_id, std::string ip_address, int port) {
std::shared_ptr<rpc::WorkerTaskClient> grpc_client =
std::make_shared<rpc::WorkerTaskClient>(ip_address, port, client_call_manager_);
std::shared_ptr<rpc::CoreWorkerClient> grpc_client =
std::make_shared<rpc::CoreWorkerClient>(ip_address, port, client_call_manager_);
RAY_CHECK(rpc_clients_.emplace(actor_id, std::move(grpc_client)).second);
// Submit all pending requests.
@@ -117,22 +113,20 @@ void CoreWorkerDirectActorTaskSubmitter::ConnectAndSendPendingTasks(
auto request = std::move(requests.front());
auto num_returns = request->task_spec().num_returns();
auto task_id = TaskID::FromBinary(request->task_spec().task_id());
DirectActorAssignTask(*client, std::move(request), actor_id, task_id, num_returns);
PushActorTask(*client, std::move(request), actor_id, task_id, num_returns);
requests.pop_front();
}
}
void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
rpc::WorkerTaskClient &client,
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request, const ActorID &actor_id,
const TaskID &task_id, int num_returns) {
void CoreWorkerDirectActorTaskSubmitter::PushActorTask(
rpc::CoreWorkerClient &client, std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns) {
RAY_LOG(DEBUG) << "Pushing task " << task_id << " to actor " << actor_id;
waiting_reply_tasks_[actor_id].insert(std::make_pair(task_id, num_returns));
auto status = client.DirectActorAssignTask(
std::move(request),
[this, actor_id, task_id, num_returns](
Status status, const rpc::DirectActorAssignTaskReply &reply) {
auto status = client.PushActorTask(
std::move(request), [this, actor_id, task_id, num_returns](
Status status, const rpc::PushTaskReply &reply) {
{
std::unique_lock<std::mutex> guard(mutex_);
waiting_reply_tasks_[actor_id].erase(task_id);
@@ -156,7 +150,7 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(
in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
} else {
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
@@ -172,8 +166,8 @@ void CoreWorkerDirectActorTaskSubmitter::DirectActorAssignTask(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(in_memory_store_->Put(RayObject(data_buffer, metadata_buffer),
object_id));
RAY_CHECK_OK(
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
}
}
});
@@ -189,11 +183,11 @@ void CoreWorkerDirectActorTaskSubmitter::TreatTaskAsFailed(
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_->Put(RayObject(nullptr, meta_buffer), object_id));
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
}
}
@@ -204,20 +198,19 @@ bool CoreWorkerDirectActorTaskSubmitter::IsActorAlive(const ActorID &actor_id) c
return (iter != actor_states_.end() && iter->second.state_ == ActorTableData::ALIVE);
}
CoreWorkerDirectActorTaskReceiver::CoreWorkerDirectActorTaskReceiver(
CoreWorkerDirectTaskReceiver::CoreWorkerDirectTaskReceiver(
WorkerContext &worker_context, boost::asio::io_service &main_io_service,
rpc::GrpcServer &server, const TaskHandler &task_handler,
const std::function<void()> &exit_handler)
const TaskHandler &task_handler, const std::function<void()> &exit_handler)
: worker_context_(worker_context),
task_handler_(task_handler),
exit_handler_(exit_handler),
task_main_io_service_(main_io_service) {}
void CoreWorkerDirectActorTaskReceiver::Init(RayletClient &raylet_client) {
void CoreWorkerDirectTaskReceiver::Init(RayletClient &raylet_client) {
waiter_.reset(new DependencyWaiterImpl(raylet_client));
}
void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
void CoreWorkerDirectTaskReceiver::SetMaxActorConcurrency(int max_concurrency) {
if (max_concurrency != max_concurrency_) {
RAY_LOG(INFO) << "Creating new thread pool of size " << max_concurrency;
RAY_CHECK(pool_ == nullptr) << "Cannot change max concurrency at runtime.";
@@ -226,13 +219,13 @@ void CoreWorkerDirectActorTaskReceiver::SetMaxActorConcurrency(int max_concurren
}
}
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply, rpc::SendReplyCallback send_reply_callback) {
void CoreWorkerDirectTaskReceiver::HandlePushTask(
const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback) {
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
const TaskSpecification task_spec(request.task_spec());
RAY_LOG(DEBUG) << "Received task " << task_spec.TaskId();
if (task_spec.IsActorTask() && !worker_context_.CurrentActorUseDirectCall()) {
if (task_spec.IsActorTask() && !worker_context_.CurrentTaskIsDirectCall()) {
send_reply_callback(Status::Invalid("This actor doesn't accept direct calls."),
nullptr, nullptr);
return;
@@ -257,62 +250,63 @@ void CoreWorkerDirectActorTaskReceiver::HandleDirectActorAssignTask(
task_main_io_service_, *waiter_, pool_)));
it = result.first;
}
it->second->Add(
request.sequence_number(), request.client_processed_up_to(),
[this, reply, send_reply_callback, task_spec]() {
auto num_returns = task_spec.NumReturns();
RAY_CHECK(task_spec.IsActorCreationTask() || task_spec.IsActorTask());
RAY_CHECK(num_returns > 0);
// Decrease to account for the dummy object id.
num_returns--;
it->second->Add(request.sequence_number(), request.client_processed_up_to(),
[this, reply, send_reply_callback, task_spec]() {
auto num_returns = task_spec.NumReturns();
RAY_CHECK(num_returns > 0);
if (task_spec.IsActorCreationTask() || task_spec.IsActorTask()) {
// Decrease to account for the dummy object id.
num_returns--;
}
// TODO(edoakes): resource IDs are currently kept track of in the raylet,
// need to come up with a solution for this.
ResourceMappingType resource_ids;
std::vector<std::shared_ptr<RayObject>> return_objects;
auto status = task_handler_(task_spec, resource_ids, &return_objects);
if (status.IsSystemExit()) {
// In Python, SystemExit can only be raised on the main thread. To work
// around this when we are executing tasks on worker threads, we re-post the
// exit event explicitly on the main thread.
task_main_io_service_.post([this]() { exit_handler_(); });
return;
}
RAY_CHECK(return_objects.size() == num_returns)
<< return_objects.size() << " " << num_returns;
// TODO(edoakes): resource IDs are currently kept track of in the
// raylet, need to come up with a solution for this.
ResourceMappingType resource_ids;
std::vector<std::shared_ptr<RayObject>> return_objects;
auto status = task_handler_(task_spec, resource_ids, &return_objects);
if (status.IsSystemExit()) {
// In Python, SystemExit can only be raised on the main thread. To
// work around this when we are executing tasks on worker threads,
// we re-post the exit event explicitly on the main thread.
task_main_io_service_.post([this]() { exit_handler_(); });
return;
}
RAY_CHECK(return_objects.size() == num_returns)
<< return_objects.size() << " " << num_returns;
for (size_t i = 0; i < return_objects.size(); i++) {
auto return_object = reply->add_return_objects();
ObjectID id = ObjectID::ForTaskReturn(
task_spec.TaskId(), /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT_ACTOR));
return_object->set_object_id(id.Binary());
for (size_t i = 0; i < return_objects.size(); i++) {
auto return_object = reply->add_return_objects();
ObjectID id = ObjectID::ForTaskReturn(
task_spec.TaskId(), /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
return_object->set_object_id(id.Binary());
// The object is nullptr if it already existed in the object store.
const auto &result = return_objects[i];
if (result == nullptr || result->GetData()->IsPlasmaBuffer()) {
return_object->set_in_plasma(true);
} else {
if (result->GetData() != nullptr) {
return_object->set_data(result->GetData()->Data(),
result->GetData()->Size());
}
if (result->GetMetadata() != nullptr) {
return_object->set_metadata(result->GetMetadata()->Data(),
result->GetMetadata()->Size());
}
}
}
// The object is nullptr if it already existed in the object store.
const auto &result = return_objects[i];
if (result == nullptr || result->GetData()->IsPlasmaBuffer()) {
return_object->set_in_plasma(true);
} else {
if (result->GetData() != nullptr) {
return_object->set_data(result->GetData()->Data(),
result->GetData()->Size());
}
if (result->GetMetadata() != nullptr) {
return_object->set_metadata(result->GetMetadata()->Data(),
result->GetMetadata()->Size());
}
}
}
send_reply_callback(status, nullptr, nullptr);
},
[send_reply_callback]() {
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr, nullptr);
},
dependencies);
send_reply_callback(status, nullptr, nullptr);
},
[send_reply_callback]() {
send_reply_callback(Status::Invalid("client cancelled rpc"), nullptr,
nullptr);
},
dependencies);
}
void CoreWorkerDirectActorTaskReceiver::HandleDirectActorCallArgWaitComplete(
void CoreWorkerDirectTaskReceiver::HandleDirectActorCallArgWaitComplete(
const rpc::DirectActorCallArgWaitCompleteRequest &request,
rpc::DirectActorCallArgWaitCompleteReply *reply,
rpc::SendReplyCallback send_reply_callback) {
@@ -15,7 +15,7 @@
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/gcs/redis_gcs_client.h"
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/worker/worker_client.h"
#include "ray/rpc/worker/core_worker_client.h"
namespace ray {
@@ -40,15 +40,14 @@ struct ActorStateData {
// This class is thread-safe.
class CoreWorkerDirectActorTaskSubmitter {
public:
CoreWorkerDirectActorTaskSubmitter(
boost::asio::io_service &io_service,
std::unique_ptr<CoreWorkerMemoryStoreProvider> store_provider);
CoreWorkerDirectActorTaskSubmitter(rpc::ClientCallManager &client_call_manager,
CoreWorkerMemoryStoreProvider store_provider);
/// Submit a task to an actor for execution.
///
/// \param[in] task The task spec to submit.
/// \return Status::Invalid if the task is not yet supported.
Status SubmitTask(const TaskSpecification &task_spec);
Status SubmitTask(TaskSpecification task_spec);
/// Handle an update about an actor.
///
@@ -67,10 +66,9 @@ class CoreWorkerDirectActorTaskSubmitter {
/// \param[in] task_id The ID of a task.
/// \param[in] num_returns Number of return objects.
/// \return Void.
void DirectActorAssignTask(rpc::WorkerTaskClient &client,
std::unique_ptr<rpc::DirectActorAssignTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id,
int num_returns);
void PushActorTask(rpc::CoreWorkerClient &client,
std::unique_ptr<rpc::PushTaskRequest> request,
const ActorID &actor_id, const TaskID &task_id, int num_returns);
/// Treat a task as failed.
///
@@ -98,11 +96,8 @@ class CoreWorkerDirectActorTaskSubmitter {
/// \return Whether this actor is alive.
bool IsActorAlive(const ActorID &actor_id) const;
/// The IO event loop.
boost::asio::io_service &io_service_;
/// The `ClientCallManager` object that is shared by all `DirectActorClient`s.
rpc::ClientCallManager client_call_manager_;
/// The shared `ClientCallManager` object.
rpc::ClientCallManager &client_call_manager_;
/// Mutex to proect the various maps below.
mutable std::mutex mutex_;
@@ -115,18 +110,17 @@ class CoreWorkerDirectActorTaskSubmitter {
///
/// TODO(zhijunfu): this will be moved into `actor_states_` later when we can
/// subscribe updates for a specific actor.
std::unordered_map<ActorID, std::shared_ptr<rpc::WorkerTaskClient>> rpc_clients_;
std::unordered_map<ActorID, std::shared_ptr<rpc::CoreWorkerClient>> rpc_clients_;
/// Map from actor id to the actor's pending requests.
std::unordered_map<ActorID,
std::list<std::unique_ptr<rpc::DirectActorAssignTaskRequest>>>
std::unordered_map<ActorID, std::list<std::unique_ptr<rpc::PushTaskRequest>>>
pending_requests_;
/// Map from actor id to the tasks that are waiting for reply.
std::unordered_map<ActorID, std::unordered_map<TaskID, int>> waiting_reply_tasks_;
/// The store provider.
std::unique_ptr<CoreWorkerMemoryStoreProvider> in_memory_store_;
CoreWorkerMemoryStoreProvider in_memory_store_;
friend class CoreWorkerTest;
};
@@ -235,6 +229,9 @@ class SchedulingQueue {
void Add(int64_t seq_no, int64_t client_processed_up_to,
std::function<void()> accept_request, std::function<void()> reject_request,
const std::vector<ObjectID> &dependencies = {}) {
if (seq_no == -1) {
seq_no = next_seq_no_; // A value of -1 means no ordering constraint.
}
RAY_CHECK(boost::this_thread::get_id() == main_thread_id_);
if (client_processed_up_to >= next_seq_no_) {
RAY_LOG(ERROR) << "client skipping requests " << next_seq_no_ << " to "
@@ -329,29 +326,27 @@ class SchedulingQueue {
friend class SchedulingQueueTest;
};
class CoreWorkerDirectActorTaskReceiver {
class CoreWorkerDirectTaskReceiver {
public:
using TaskHandler = std::function<Status(
const TaskSpecification &task_spec, const ResourceMappingType &resource_ids,
std::vector<std::shared_ptr<RayObject>> *return_objects)>;
CoreWorkerDirectActorTaskReceiver(WorkerContext &worker_context,
boost::asio::io_service &main_io_service,
rpc::GrpcServer &server,
const TaskHandler &task_handler,
const std::function<void()> &exit_handler);
CoreWorkerDirectTaskReceiver(WorkerContext &worker_context,
boost::asio::io_service &main_io_service,
const TaskHandler &task_handler,
const std::function<void()> &exit_handler);
/// Initialize this receiver. This must be called prior to use.
void Init(RayletClient &client);
/// Handle a `DirectActorAssignTask` request.
/// Handle a `PushTask` request.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] send_reply_callback The callback to be called when the request is done.
void HandleDirectActorAssignTask(const rpc::DirectActorAssignTaskRequest &request,
rpc::DirectActorAssignTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
void HandlePushTask(const rpc::PushTaskRequest &request, rpc::PushTaskReply *reply,
rpc::SendReplyCallback send_reply_callback);
/// Handle a `DirectActorCallArgWaitComplete` request.
///
@@ -0,0 +1,186 @@
#include "ray/core_worker/transport/direct_task_transport.h"
namespace ray {
void DoInlineObjectValue(const ObjectID &obj_id, std::shared_ptr<RayObject> value,
TaskSpecification &task) {
auto &msg = task.GetMutableMessage();
bool found = false;
for (size_t i = 0; i < task.NumArgs(); i++) {
auto count = task.ArgIdCount(i);
if (count > 0) {
const auto &id = task.ArgId(i, 0);
if (id == obj_id) {
auto *mutable_arg = msg.mutable_args(i);
mutable_arg->clear_object_ids();
if (value->HasData()) {
const auto &data = value->GetData();
mutable_arg->set_data(data->Data(), data->Size());
}
if (value->HasMetadata()) {
const auto &metadata = value->GetMetadata();
mutable_arg->set_metadata(metadata->Data(), metadata->Size());
}
found = true;
}
}
}
RAY_CHECK(found) << "obj id " << obj_id << " not found";
}
void LocalDependencyResolver::ResolveDependencies(const TaskSpecification &task,
std::function<void()> on_complete) {
absl::flat_hash_set<ObjectID> local_dependencies;
for (size_t i = 0; i < task.NumArgs(); i++) {
auto count = task.ArgIdCount(i);
if (count > 0) {
RAY_CHECK(count <= 1) << "multi args not implemented";
const auto &id = task.ArgId(i, 0);
if (id.IsDirectCallType()) {
local_dependencies.insert(id);
}
}
}
if (local_dependencies.empty()) {
on_complete();
return;
}
// This is deleted when the last dependency fetch callback finishes.
std::shared_ptr<TaskState> state =
std::shared_ptr<TaskState>(new TaskState{task, std::move(local_dependencies)});
num_pending_ += 1;
for (const auto &obj_id : state->local_dependencies) {
in_memory_store_.GetAsync(
obj_id, [this, state, obj_id, on_complete](std::shared_ptr<RayObject> obj) {
RAY_CHECK(obj != nullptr);
bool complete = false;
{
absl::MutexLock lock(&mu_);
state->local_dependencies.erase(obj_id);
DoInlineObjectValue(obj_id, obj, state->task);
if (state->local_dependencies.empty()) {
complete = true;
num_pending_ -= 1;
}
}
if (complete) {
on_complete();
}
});
}
}
Status CoreWorkerDirectTaskSubmitter::SubmitTask(TaskSpecification task_spec) {
resolver_.ResolveDependencies(task_spec, [this, task_spec]() {
// TODO(ekl) should have a queue per distinct resource type required
absl::MutexLock lock(&mu_);
RequestNewWorkerIfNeeded(task_spec);
queued_tasks_.push_back(task_spec);
// The task is now queued and will be picked up by the next leased or newly
// idle worker. We are guaranteed a worker will show up since we called
// RequestNewWorkerIfNeeded() earlier while holding mu_.
});
return Status::OK();
}
void CoreWorkerDirectTaskSubmitter::HandleWorkerLeaseGranted(const WorkerAddress addr) {
// Setup client state for this worker.
{
absl::MutexLock lock(&mu_);
worker_request_pending_ = false;
auto it = client_cache_.find(addr);
if (it == client_cache_.end()) {
client_cache_[addr] =
std::shared_ptr<rpc::CoreWorkerClientInterface>(client_factory_(addr));
RAY_LOG(INFO) << "Connected to " << addr.first << ":" << addr.second;
}
}
// Try to assign it work.
OnWorkerIdle(addr, /*error=*/false);
}
void CoreWorkerDirectTaskSubmitter::OnWorkerIdle(const WorkerAddress &addr,
bool was_error) {
absl::MutexLock lock(&mu_);
if (queued_tasks_.empty() || was_error) {
RAY_CHECK_OK(lease_client_.ReturnWorker(addr.second));
} else {
auto &client = *client_cache_[addr];
PushNormalTask(addr, client, queued_tasks_.front());
queued_tasks_.pop_front();
}
// We have a queue of tasks, try to request more workers.
if (!queued_tasks_.empty()) {
RequestNewWorkerIfNeeded(queued_tasks_.front());
}
}
void CoreWorkerDirectTaskSubmitter::RequestNewWorkerIfNeeded(
const TaskSpecification &resource_spec) {
if (worker_request_pending_) {
return;
}
RAY_CHECK_OK(lease_client_.RequestWorkerLease(resource_spec));
worker_request_pending_ = true;
}
void CoreWorkerDirectTaskSubmitter::TreatTaskAsFailed(const TaskID &task_id,
int num_returns,
const rpc::ErrorType &error_type) {
RAY_LOG(DEBUG) << "Treat task as failed. task_id: " << task_id
<< ", error_type: " << ErrorType_Name(error_type);
for (int i = 0; i < num_returns; i++) {
const auto object_id = ObjectID::ForTaskReturn(
task_id, /*index=*/i + 1,
/*transport_type=*/static_cast<int>(TaskTransportType::DIRECT));
std::string meta = std::to_string(static_cast<int>(error_type));
auto metadata = const_cast<uint8_t *>(reinterpret_cast<const uint8_t *>(meta.data()));
auto meta_buffer = std::make_shared<LocalMemoryBuffer>(metadata, meta.size());
RAY_CHECK_OK(in_memory_store_.Put(RayObject(nullptr, meta_buffer), object_id));
}
}
// TODO(ekl) consider reconsolidating with DirectActorTransport.
void CoreWorkerDirectTaskSubmitter::PushNormalTask(const WorkerAddress &addr,
rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec) {
auto task_id = task_spec.TaskId();
auto num_returns = task_spec.NumReturns();
auto request = std::unique_ptr<rpc::PushTaskRequest>(new rpc::PushTaskRequest);
request->mutable_task_spec()->Swap(&task_spec.GetMutableMessage());
auto status = client.PushNormalTask(
std::move(request),
[this, task_id, num_returns, addr](Status status, const rpc::PushTaskReply &reply) {
OnWorkerIdle(addr, /*error=*/!status.ok());
if (!status.ok()) {
TreatTaskAsFailed(task_id, num_returns, rpc::ErrorType::WORKER_DIED);
return;
}
for (int i = 0; i < reply.return_objects_size(); i++) {
const auto &return_object = reply.return_objects(i);
ObjectID object_id = ObjectID::FromBinary(return_object.object_id());
std::shared_ptr<LocalMemoryBuffer> data_buffer;
if (return_object.data().size() > 0) {
data_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.data().data())),
return_object.data().size());
}
std::shared_ptr<LocalMemoryBuffer> metadata_buffer;
if (return_object.metadata().size() > 0) {
metadata_buffer = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(
reinterpret_cast<const uint8_t *>(return_object.metadata().data())),
return_object.metadata().size());
}
RAY_CHECK_OK(
in_memory_store_.Put(RayObject(data_buffer, metadata_buffer), object_id));
}
});
RAY_CHECK_OK(status);
}
}; // namespace ray
@@ -0,0 +1,128 @@
#ifndef RAY_CORE_WORKER_DIRECT_TASK_H
#define RAY_CORE_WORKER_DIRECT_TASK_H
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "ray/common/id.h"
#include "ray/common/ray_object.h"
#include "ray/core_worker/context.h"
#include "ray/core_worker/store_provider/memory_store_provider.h"
#include "ray/core_worker/transport/direct_actor_transport.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/core_worker_client.h"
namespace ray {
struct TaskState {
/// The task to be run.
TaskSpecification task;
/// The remaining dependencies to resolve for this task.
absl::flat_hash_set<ObjectID> local_dependencies;
};
// This class is thread-safe.
class LocalDependencyResolver {
public:
LocalDependencyResolver(CoreWorkerMemoryStoreProvider &store_provider)
: in_memory_store_(store_provider), num_pending_(0) {}
/// Resolve all local and remote dependencies for the task, calling the specified
/// callback when done. Direct call ids in the task specification will be resolved
/// to concrete values and inlined.
//
/// Note: This method **will mutate** the given TaskSpecification.
///
/// Postcondition: all direct call ids in arguments are converted to values.
void ResolveDependencies(const TaskSpecification &task,
std::function<void()> on_complete);
/// Return the number of tasks pending dependency resolution.
/// TODO(ekl) this should be exposed in worker stats.
int NumPendingTasks() const { return num_pending_; }
private:
/// The store provider.
CoreWorkerMemoryStoreProvider in_memory_store_;
/// Number of tasks pending dependency resolution.
std::atomic<int> num_pending_;
/// Protects against concurrent access to internal state.
absl::Mutex mu_;
};
typedef std::pair<std::string, int> WorkerAddress;
typedef std::function<std::shared_ptr<rpc::CoreWorkerClientInterface>(WorkerAddress)>
ClientFactoryFn;
// This class is thread-safe.
class CoreWorkerDirectTaskSubmitter {
public:
CoreWorkerDirectTaskSubmitter(WorkerLeaseInterface &lease_client,
ClientFactoryFn client_factory,
CoreWorkerMemoryStoreProvider store_provider)
: lease_client_(lease_client),
client_factory_(client_factory),
in_memory_store_(store_provider),
resolver_(in_memory_store_) {}
/// Schedule a task for direct submission to a worker.
///
/// \param[in] task_spec The task to schedule.
Status SubmitTask(TaskSpecification task_spec);
/// Callback for when the raylet grants us a worker lease. The worker is returned
/// to the raylet once it finishes its task and either the lease term has
/// expired, or there is no more work it can take on.
///
/// \param[in] addr The (addr, port) pair identifying the worker.
void HandleWorkerLeaseGranted(const WorkerAddress addr);
private:
/// Schedule more work onto an idle worker or return it back to the raylet if
/// no more tasks are queued for submission. If an error was encountered
/// processing the worker, we don't attempt to re-use the worker.
void OnWorkerIdle(const WorkerAddress &addr, bool was_error);
/// Request a new worker from the raylet if no such requests are currently in
/// flight.
void RequestNewWorkerIfNeeded(const TaskSpecification &resource_spec)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Push a task to a specific worker.
void PushNormalTask(const WorkerAddress &addr, rpc::CoreWorkerClientInterface &client,
TaskSpecification &task_spec);
/// Mark a direct call as failed by storing errors for its return objects.
void TreatTaskAsFailed(const TaskID &task_id, int num_returns,
const rpc::ErrorType &error_type);
// Client that can be used to lease and return workers.
WorkerLeaseInterface &lease_client_;
/// Factory for producing new core worker clients.
ClientFactoryFn client_factory_;
/// The store provider.
CoreWorkerMemoryStoreProvider in_memory_store_;
/// Resolve local and remote dependencies;
LocalDependencyResolver resolver_;
// Protects task submission state below.
absl::Mutex mu_;
/// Cache of gRPC clients to other workers.
absl::flat_hash_map<WorkerAddress, std::shared_ptr<rpc::CoreWorkerClientInterface>>
client_cache_ GUARDED_BY(mu_);
// Whether we have a request to the Raylet to acquire a new worker in flight.
bool worker_request_pending_ GUARDED_BY(mu_) = false;
// Tasks that are queued for execution in this submitter..
std::deque<TaskSpecification> queued_tasks_ GUARDED_BY(mu_);
};
}; // namespace ray
#endif // RAY_CORE_WORKER_DIRECT_TASK_H
@@ -5,7 +5,7 @@
#include "ray/common/ray_object.h"
#include "ray/raylet/raylet_client.h"
#include "ray/rpc/worker/worker_server.h"
#include "ray/rpc/worker/core_worker_server.h"
namespace ray {
+2
View File
@@ -66,6 +66,8 @@ message TaskSpec {
// Task specification for an actor task.
// This field is only valid when `type == ACTOR_TASK`.
ActorTaskSpec actor_task_spec = 14;
// Whether this task is a direct call task.
bool is_direct_call = 15;
}
// Argument in the task.
+22 -9
View File
@@ -56,22 +56,24 @@ message ReturnObject {
bytes metadata = 4;
}
message DirectActorAssignTaskRequest {
message PushTaskRequest {
// The task to be pushed.
TaskSpec task_spec = 1;
// The sequence number of the task for this client. This must increase
// sequentially starting from zero for each actor handle. The server
// will guarantee tasks execute in this sequence, waiting for any
// out-of-order request messages to arrive as necessary.
// If set to -1, ordering is disabled and the task executes immediately.
// This mode of behaviour is used for direct task submission only.
int64 sequence_number = 2;
// The max sequence number the client has processed responses for. This
// is a performance optimization that allows the client to tell the server
// to cancel any DirectActorAssignTaskRequests with seqno <= this value, rather than
// to cancel any PushTaskRequests with seqno <= this value, rather than
// waiting for the server to time out waiting for missing messages.
int64 client_processed_up_to = 3;
}
message DirectActorAssignTaskReply {
message PushTaskReply {
// The returned objects.
repeated ReturnObject return_objects = 1;
}
@@ -85,13 +87,24 @@ message DirectActorCallArgWaitCompleteRequest {
message DirectActorCallArgWaitCompleteReply {
}
service WorkerService {
// Push a task to a worker.
message WorkerLeaseGrantedRequest {
// Address of the leased worker.
string address = 1;
// Port of the leased worker.
int32 port = 2;
}
message WorkerLeaseGrantedReply {
}
service CoreWorkerService {
// Push a task to a worker from the raylet.
rpc AssignTask(AssignTaskRequest) returns (AssignTaskReply);
// Push a task to a direct-call actor.
rpc DirectActorAssignTask(DirectActorAssignTaskRequest)
returns (DirectActorAssignTaskReply);
// Notify wait for direct actor call args has completed.
// Push a task directly to this worker from another.
rpc PushTask(PushTaskRequest) returns (PushTaskReply);
// Reply from raylet that wait for direct actor call args has completed.
rpc DirectActorCallArgWaitComplete(DirectActorCallArgWaitCompleteRequest)
returns (DirectActorCallArgWaitCompleteReply);
// Reply from raylet to fulfill a worker lease request.
rpc WorkerLeaseGranted(WorkerLeaseGrantedRequest) returns (WorkerLeaseGrantedReply);
}
+12
View File
@@ -74,6 +74,10 @@ enum MessageType:int {
SetResourceRequest,
// Update the active set of object IDs in use on this worker.
ReportActiveObjectIDs,
// Request a worker from the raylet with the specified resources.
RequestWorkerLease,
// Returns a worker to the raylet.
ReturnWorker,
}
table TaskExecutionSpecification {
@@ -91,6 +95,14 @@ table Task {
task_execution_spec: TaskExecutionSpecification;
}
table WorkerLeaseRequest {
resource_spec: string;
}
table ReturnWorkerRequest {
worker_port: int;
}
// This message describes a given resource that is reserved for a worker.
table ResourceIdSetInfo {
// The name of the resource.
+2 -1
View File
@@ -139,7 +139,8 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
uint64_t num_returns) {
TaskSpecBuilder builder;
builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(),
RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {});
RandomTaskId(), 0, RandomTaskId(), num_returns, false, {},
{});
for (const auto &arg : arguments) {
builder.AddByRefArg(arg);
}
+54 -1
View File
@@ -2,6 +2,7 @@
#include <cctype>
#include <fstream>
#include <memory>
#include "ray/common/status.h"
@@ -895,6 +896,12 @@ void NodeManager::ProcessClientMessage(
// because it's already disconnected.
return;
} break;
case protocol::MessageType::RequestWorkerLease: {
ProcessRequestWorkerLeaseMessage(client, message_data);
} break;
case protocol::MessageType::ReturnWorker: {
ProcessReturnWorkerMessage(message_data);
} break;
case protocol::MessageType::SetResourceRequest: {
ProcessSetResourceRequest(client, message_data);
} break;
@@ -1031,6 +1038,10 @@ void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_loca
void NodeManager::HandleWorkerAvailable(
const std::shared_ptr<LocalClientConnection> &client) {
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
HandleWorkerAvailable(worker);
}
void NodeManager::HandleWorkerAvailable(const std::shared_ptr<Worker> &worker) {
RAY_CHECK(worker);
// If the worker was assigned a task, mark it as finished.
if (!worker->GetAssignedTaskId().IsNil()) {
@@ -1172,6 +1183,43 @@ void NodeManager::ProcessDisconnectClientMessage(
// these can be leaked.
}
void NodeManager::ProcessRequestWorkerLeaseMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
// Read the resource spec submitted by the client.
auto fbs_message = flatbuffers::GetRoot<protocol::WorkerLeaseRequest>(message_data);
rpc::Task task_message;
RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray(
fbs_message->resource_spec()->data(), fbs_message->resource_spec()->size()));
// Override the task dispatch to call back to the client instead of executing the
// task directly on the worker. TODO(ekl) handle spilling case
Task task(task_message);
task.OnDispatchInstead([this, client](const std::shared_ptr<void> granted,
const std::string &address, int port) {
std::shared_ptr<Worker> client_worker = worker_pool_.GetRegisteredWorker(client);
if (client_worker == nullptr) {
client_worker = worker_pool_.GetRegisteredDriver(client);
}
if (client_worker == nullptr) {
RAY_LOG(FATAL) << "TODO: Lost worker for lease request " << client;
} else {
client_worker->WorkerLeaseGranted(address, port);
leased_workers_[port] = std::static_pointer_cast<Worker>(granted);
}
});
SubmitTask(task, Lineage());
}
void NodeManager::ProcessReturnWorkerMessage(const uint8_t *message_data) {
// Read the resource spec submitted by the client.
auto fbs_message = flatbuffers::GetRoot<protocol::ReturnWorkerRequest>(message_data);
auto worker_port = fbs_message->worker_port();
RAY_LOG(DEBUG) << "Return worker " << worker_port;
std::shared_ptr<Worker> worker = leased_workers_[worker_port];
leased_workers_.erase(worker_port);
HandleWorkerAvailable(worker);
}
void NodeManager::ProcessFetchOrReconstructMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data) {
auto message = flatbuffers::GetRoot<protocol::FetchOrReconstruct>(message_data);
@@ -1955,7 +2003,12 @@ bool NodeManager::AssignTask(const Task &task) {
ResourceIdSet resource_id_set =
worker->GetTaskResourceIds().Plus(worker->GetLifetimeResourceIds());
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
if (task.OnDispatch() != nullptr) {
task.OnDispatch()(worker, initial_config_.node_manager_address, worker->Port());
finish_assign_task_callback(Status::OK());
} else {
worker->AssignTask(task, resource_id_set, finish_assign_task_callback);
}
// We assigned this task to a worker.
// (Note this means that we sent the task to the worker. The assignment
+24 -1
View File
@@ -394,6 +394,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// \return Void.
void HandleWorkerAvailable(const std::shared_ptr<LocalClientConnection> &client);
/// Handle the case that a worker is available.
///
/// \param worker The pointer to the worker
/// \return Void.
void HandleWorkerAvailable(const std::shared_ptr<Worker> &worker);
/// Handle a client that has disconnected. This can be called multiple times
/// on the same client because this is triggered both when a client
/// disconnects and when the node manager fails to write a message to the
@@ -406,6 +412,20 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
const std::shared_ptr<LocalClientConnection> &client,
bool intentional_disconnect = false);
/// Process client message of RequestWorkerLease
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessRequestWorkerLeaseMessage(
const std::shared_ptr<LocalClientConnection> &client, const uint8_t *message_data);
/// Process client message of ReturnWorkerMessage
///
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessReturnWorkerMessage(const uint8_t *message_data);
/// Process client message of FetchOrReconstruct
///
/// \param client The client that sent the message.
@@ -574,12 +594,15 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
rpc::NodeManagerGrpcService node_manager_service_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s
/// as well as all `WorkerTaskClient`s.
/// as well as all `CoreWorkerClient`s.
rpc::ClientCallManager client_call_manager_;
/// Map from node ids to clients of the remote node managers.
std::unordered_map<ClientID, std::unique_ptr<rpc::NodeManagerClient>>
remote_node_manager_clients_;
/// Map of workers leased out to direct call clients.
std::unordered_map<int, std::shared_ptr<Worker>> leased_workers_;
};
} // namespace raylet
+16
View File
@@ -374,3 +374,19 @@ ray::Status RayletClient::ReportActiveObjectIDs(
return conn_->WriteMessage(MessageType::ReportActiveObjectIDs, &fbb);
}
ray::Status RayletClient::RequestWorkerLease(
const ray::TaskSpecification &resource_spec) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateWorkerLeaseRequest(
fbb, fbb.CreateString(resource_spec.Serialize()));
fbb.Finish(message);
return conn_->WriteMessage(MessageType::RequestWorkerLease, &fbb);
}
ray::Status RayletClient::ReturnWorker(int worker_port) {
flatbuffers::FlatBufferBuilder fbb;
auto message = ray::protocol::CreateReturnWorkerRequest(fbb, worker_port);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::ReturnWorker, &fbb);
}
+21 -1
View File
@@ -63,7 +63,21 @@ class RayletConnection {
std::mutex write_mutex_;
};
class RayletClient {
/// Interface for leasing workers. Abstract for testing.
class WorkerLeaseInterface {
public:
/// Requests a worker from the raylet. The callback will be sent via gRPC.
/// \param resource_spec Resources that should be allocated for the worker.
/// \return ray::Status
virtual ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) = 0;
/// Returns a worker to the raylet.
/// \param worker_port The local port of the worker on the raylet node.
/// \return ray::Status
virtual ray::Status ReturnWorker(int worker_port) = 0;
};
class RayletClient : public WorkerLeaseInterface {
public:
/// Connect to the raylet.
///
@@ -185,6 +199,12 @@ class RayletClient {
/// \return ray::Status
ray::Status ReportActiveObjectIDs(const std::unordered_set<ObjectID> &object_ids);
/// Implements WorkerLeaseInterface.
ray::Status RequestWorkerLease(const ray::TaskSpecification &resource_spec) override;
/// Implements WorkerLeaseInterface.
ray::Status ReturnWorker(int worker_port) override;
Language GetLanguage() const { return language_; }
WorkerID GetWorkerID() const { return worker_id_; }
@@ -76,7 +76,8 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
uint64_t num_returns) {
TaskSpecBuilder builder;
builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(),
RandomTaskId(), 0, RandomTaskId(), num_returns, {}, {});
RandomTaskId(), 0, RandomTaskId(), num_returns, false, {},
{});
for (const auto &arg : arguments) {
builder.AddByRefArg(arg);
}
+20 -2
View File
@@ -25,8 +25,8 @@ Worker::Worker(const WorkerID &worker_id, pid_t pid, const Language &language, i
client_call_manager_(client_call_manager),
is_detached_actor_(false) {
if (port_ > 0) {
rpc_client_ = std::unique_ptr<rpc::WorkerTaskClient>(
new rpc::WorkerTaskClient("127.0.0.1", port_, client_call_manager_));
rpc_client_ = std::unique_ptr<rpc::CoreWorkerClient>(
new rpc::CoreWorkerClient("127.0.0.1", port_, client_call_manager_));
}
}
@@ -172,6 +172,24 @@ void Worker::DirectActorCallArgWaitComplete(int64_t tag) {
}
}
void Worker::WorkerLeaseGranted(const std::string &address, int port) {
RAY_CHECK(!address.empty());
RAY_CHECK(port_ > 0);
rpc::WorkerLeaseGrantedRequest request;
request.set_address(address);
request.set_port(port);
auto status = rpc_client_->WorkerLeaseGranted(
request, [address, port](Status status, const rpc::WorkerLeaseGrantedReply &reply) {
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString()
<< " for " << address << ":" << port;
}
});
if (!status.ok()) {
RAY_LOG(ERROR) << "Failed to reply to lease request: " << status.ToString();
}
}
} // namespace raylet
} // end namespace ray
+4 -3
View File
@@ -8,7 +8,7 @@
#include "ray/common/task/scheduling_resources.h"
#include "ray/common/task/task.h"
#include "ray/common/task/task_common.h"
#include "ray/rpc/worker/worker_client.h"
#include "ray/rpc/worker/core_worker_client.h"
#include <unistd.h> // pid_t
@@ -67,6 +67,7 @@ class Worker {
void AssignTask(const Task &task, const ResourceIdSet &resource_id_set,
const std::function<void(Status)> finish_assign_callback);
void DirectActorCallArgWaitComplete(int64_t tag);
void WorkerLeaseGranted(const std::string &address, int port);
private:
/// The worker's ID.
@@ -100,11 +101,11 @@ class Worker {
std::unordered_set<TaskID> blocked_task_ids_;
/// The set of object IDs that are currently in use on the worker.
std::unordered_set<ObjectID> active_object_ids_;
/// The `ClientCallManager` object that is shared by `WorkerTaskClient` from all
/// The `ClientCallManager` object that is shared by `CoreWorkerClient` from all
/// workers.
rpc::ClientCallManager &client_call_manager_;
/// The rpc client to send tasks to this worker.
std::unique_ptr<rpc::WorkerTaskClient> rpc_client_;
std::unique_ptr<rpc::CoreWorkerClient> rpc_client_;
/// Whether the worker is detached. This is applies when the worker is actor.
/// Detached actor means the actor's creator can exit without killing this actor.
bool is_detached_actor_;
@@ -1,5 +1,5 @@
#ifndef RAY_RPC_WORKER_CLIENT_H
#define RAY_RPC_WORKER_CLIENT_H
#ifndef RAY_RPC_CORE_WORKER_CLIENT_H
#define RAY_RPC_CORE_WORKER_CLIENT_H
#include <deque>
#include <memory>
@@ -25,7 +25,7 @@ const int64_t kMaxBytesInFlight = 16 * 1024 * 1024;
const int64_t kBaseRequestSize = 1024;
/// Get the estimated size in bytes of the given task.
const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &request) {
const static int64_t RequestSizeInBytes(const PushTaskRequest &request) {
int64_t size = kBaseRequestSize;
for (auto &arg : request.task_spec().args()) {
size += arg.data().size();
@@ -33,44 +33,87 @@ const static int64_t RequestSizeInBytes(const DirectActorAssignTaskRequest &requ
return size;
}
/// Abstract client interface for testing.
class CoreWorkerClientInterface {
public:
/// This is called by the Raylet to assign a task to the worker.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
virtual ray::Status AssignTask(const AssignTaskRequest &request,
const ClientCallback<AssignTaskReply> &callback) {
return Status::NotImplemented("");
}
/// Push an actor task directly from worker to worker.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
virtual ray::Status PushActorTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) {
return Status::NotImplemented("");
}
/// Similar to PushActorTask, but sets no ordering constraint. This is used to
/// push non-actor tasks directly to a worker.
virtual ray::Status PushNormalTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) {
return Status::NotImplemented("");
}
/// Notify a wait has completed for direct actor call arguments.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
virtual ray::Status DirectActorCallArgWaitComplete(
const DirectActorCallArgWaitCompleteRequest &request,
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
return Status::NotImplemented("");
}
/// Grants a worker to the client.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
virtual ray::Status WorkerLeaseGranted(
const WorkerLeaseGrantedRequest &request,
const ClientCallback<WorkerLeaseGrantedReply> &callback) {
return Status::NotImplemented("");
}
};
/// Client used for communicating with a remote worker server.
class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
class CoreWorkerClient : public std::enable_shared_from_this<CoreWorkerClient>,
public CoreWorkerClientInterface {
public:
/// Constructor.
///
/// \param[in] address Address of the worker server.
/// \param[in] port Port of the worker server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
WorkerTaskClient(const std::string &address, const int port,
CoreWorkerClient(const std::string &address, const int port,
ClientCallManager &client_call_manager)
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = WorkerService::NewStub(channel);
stub_ = CoreWorkerService::NewStub(channel);
};
/// Assign a task to the work.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status AssignTask(const AssignTaskRequest &request,
const ClientCallback<AssignTaskReply> &callback) {
auto call =
client_call_manager_
.CreateCall<WorkerService, AssignTaskRequest, AssignTaskReply>(
*stub_, &WorkerService::Stub::PrepareAsyncAssignTask, request, callback);
const ClientCallback<AssignTaskReply> &callback) override {
auto call = client_call_manager_
.CreateCall<CoreWorkerService, AssignTaskRequest, AssignTaskReply>(
*stub_, &CoreWorkerService::Stub::PrepareAsyncAssignTask, request,
callback);
return call->GetStatus();
}
/// Push a task.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status DirectActorAssignTask(
std::unique_ptr<DirectActorAssignTaskRequest> request,
const ClientCallback<DirectActorAssignTaskReply> &callback) {
ray::Status PushActorTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) override {
request->set_sequence_number(request->task_spec().actor_task_spec().actor_counter());
{
std::lock_guard<std::mutex> lock(mutex_);
@@ -85,20 +128,36 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
return ray::Status::OK();
}
/// Notify a wait has completed for direct actor call arguments.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
/// \return if the rpc call succeeds
ray::Status PushNormalTask(std::unique_ptr<PushTaskRequest> request,
const ClientCallback<PushTaskReply> &callback) override {
request->set_sequence_number(-1);
request->set_client_processed_up_to(-1);
auto call = client_call_manager_
.CreateCall<CoreWorkerService, PushTaskRequest, PushTaskReply>(
*stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request,
callback);
return call->GetStatus();
}
ray::Status DirectActorCallArgWaitComplete(
const DirectActorCallArgWaitCompleteRequest &request,
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) {
const ClientCallback<DirectActorCallArgWaitCompleteReply> &callback) override {
auto call = client_call_manager_.CreateCall<CoreWorkerService,
DirectActorCallArgWaitCompleteRequest,
DirectActorCallArgWaitCompleteReply>(
*stub_, &CoreWorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
request, callback);
return call->GetStatus();
}
ray::Status WorkerLeaseGranted(
const WorkerLeaseGrantedRequest &request,
const ClientCallback<WorkerLeaseGrantedReply> &callback) override {
auto call =
client_call_manager_
.CreateCall<WorkerService, DirectActorCallArgWaitCompleteRequest,
DirectActorCallArgWaitCompleteReply>(
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorCallArgWaitComplete,
request, callback);
client_call_manager_.CreateCall<CoreWorkerService, WorkerLeaseGrantedRequest,
WorkerLeaseGrantedReply>(
*stub_, &CoreWorkerService::Stub::PrepareAsyncWorkerLeaseGranted, request,
callback);
return call->GetStatus();
}
@@ -122,11 +181,10 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
request->set_client_processed_up_to(max_finished_seq_no_);
rpc_bytes_in_flight_ += task_size;
client_call_manager_.CreateCall<WorkerService, DirectActorAssignTaskRequest,
DirectActorAssignTaskReply>(
*stub_, &WorkerService::Stub::PrepareAsyncDirectActorAssignTask, *request,
[this, this_ptr, seq_no, task_size, callback](
Status status, const rpc::DirectActorAssignTaskReply &reply) {
client_call_manager_.CreateCall<CoreWorkerService, PushTaskRequest, PushTaskReply>(
*stub_, &CoreWorkerService::Stub::PrepareAsyncPushTask, *request,
[this, this_ptr, seq_no, task_size, callback](Status status,
const rpc::PushTaskReply &reply) {
{
std::lock_guard<std::mutex> lock(mutex_);
if (seq_no > max_finished_seq_no_) {
@@ -150,14 +208,13 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
std::mutex mutex_;
/// The gRPC-generated stub.
std::unique_ptr<WorkerService::Stub> stub_;
std::unique_ptr<CoreWorkerService::Stub> stub_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;
/// Queue of requests to send.
std::deque<std::pair<std::unique_ptr<DirectActorAssignTaskRequest>,
ClientCallback<DirectActorAssignTaskReply>>>
std::deque<std::pair<std::unique_ptr<PushTaskRequest>, ClientCallback<PushTaskReply>>>
send_queue_ GUARDED_BY(mutex_);
/// The number of bytes currently in flight.
@@ -174,4 +231,4 @@ class WorkerTaskClient : public std::enable_shared_from_this<WorkerTaskClient> {
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_WORKER_CLIENT_H
#endif // RAY_RPC_CORE_WORKER_CLIENT_H
@@ -1,23 +1,23 @@
#include "ray/rpc/worker/worker_server.h"
#include "ray/rpc/worker/core_worker_server.h"
#include "ray/core_worker/core_worker.h"
namespace ray {
namespace rpc {
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
new ServerCallFactoryImpl<WorkerService, CoreWorker, HANDLER##Request, \
HANDLER##Reply>( \
service_, &WorkerService::AsyncService::Request##HANDLER, core_worker_, \
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
server_call_factories_and_concurrencies->emplace_back( \
#define RAY_CORE_WORKER_RPC_HANDLER(HANDLER, CONCURRENCY) \
std::unique_ptr<ServerCallFactory> HANDLER##_call_factory( \
new ServerCallFactoryImpl<CoreWorkerService, CoreWorker, HANDLER##Request, \
HANDLER##Reply>( \
service_, &CoreWorkerService::AsyncService::Request##HANDLER, core_worker_, \
&CoreWorker::Handle##HANDLER, cq, main_service_)); \
server_call_factories_and_concurrencies->emplace_back( \
std::move(HANDLER##_call_factory), CONCURRENCY);
WorkerGrpcService::WorkerGrpcService(boost::asio::io_service &main_service,
CoreWorker &core_worker)
CoreWorkerGrpcService::CoreWorkerGrpcService(boost::asio::io_service &main_service,
CoreWorker &core_worker)
: GrpcService(main_service), core_worker_(core_worker){};
void WorkerGrpcService::InitServerCallFactories(
void CoreWorkerGrpcService::InitServerCallFactories(
const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) {
@@ -1,5 +1,5 @@
#ifndef RAY_RPC_WORKER_SERVER_H
#define RAY_RPC_WORKER_SERVER_H
#ifndef RAY_RPC_CORE_WORKER_SERVER_H
#define RAY_RPC_CORE_WORKER_SERVER_H
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
@@ -13,14 +13,14 @@ class CoreWorker;
namespace rpc {
/// The `GrpcServer` for `WorkerService`.
class WorkerGrpcService : public GrpcService {
/// The `GrpcServer` for `CoreWorkerService`.
class CoreWorkerGrpcService : public GrpcService {
public:
/// Constructor.
///
/// \param[in] main_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
WorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
CoreWorkerGrpcService(boost::asio::io_service &main_service, CoreWorker &core_worker);
protected:
grpc::Service &GetGrpcService() override { return service_; }
@@ -32,7 +32,7 @@ class WorkerGrpcService : public GrpcService {
private:
/// The grpc async service object.
WorkerService::AsyncService service_;
CoreWorkerService::AsyncService service_;
/// The core worker that actually handles the requests.
CoreWorker &core_worker_;
@@ -41,4 +41,4 @@ class WorkerGrpcService : public GrpcService {
} // namespace rpc
} // namespace ray
#endif
#endif // RAY_RPC_CORE_WORKER_SERVER_H
+16
View File
@@ -3,6 +3,8 @@
#include <string>
#include "ray/common/buffer.h"
#include "ray/common/ray_object.h"
#include "ray/util/util.h"
namespace ray {
@@ -40,6 +42,20 @@ inline TaskID RandomTaskId() {
return TaskID::FromBinary(data);
}
std::shared_ptr<Buffer> GenerateRandomBuffer() {
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
std::mt19937 gen(seed);
std::uniform_int_distribution<> dis(1, 10);
std::uniform_int_distribution<> value_dis(1, 255);
std::vector<uint8_t> arg1(dis(gen), value_dis(gen));
return std::make_shared<LocalMemoryBuffer>(arg1.data(), arg1.size(), true);
}
std::shared_ptr<RayObject> GenerateRandomObject() {
return std::shared_ptr<RayObject>(new RayObject(GenerateRandomBuffer(), nullptr));
}
} // namespace ray
#endif // RAY_UTIL_TEST_UTIL_H