mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:39:18 +08:00
[gRPC] Use gRPC for inter-node-manager communication (#4968)
This commit is contained in:
@@ -2,3 +2,5 @@
|
||||
build --compilation_mode=opt
|
||||
build --action_env=PATH
|
||||
build --action_env=PYTHON_BIN_PATH
|
||||
# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341
|
||||
build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD"
|
||||
|
||||
+1
-1
@@ -126,7 +126,7 @@ matrix:
|
||||
|
||||
- ./ci/suppress_output ./ci/travis/install-dependencies.sh
|
||||
# This command should be kept in sync with ray/python/README-building-wheels.md.
|
||||
- ./python/build-wheel-macos.sh
|
||||
- ./ci/suppress_output ./python/build-wheel-macos.sh
|
||||
script:
|
||||
- if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi
|
||||
|
||||
|
||||
+26
@@ -1,12 +1,37 @@
|
||||
# Bazel build
|
||||
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
|
||||
|
||||
load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
|
||||
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
|
||||
load("@//bazel:ray.bzl", "flatbuffer_py_library")
|
||||
load("@//bazel:cython_library.bzl", "pyx_library")
|
||||
|
||||
COPTS = ["-DRAY_USE_GLOG"]
|
||||
|
||||
# Node manager gRPC lib.
|
||||
grpc_proto_library(
|
||||
name = "node_manager_grpc_lib",
|
||||
srcs = ["src/ray/protobuf/node_manager.proto"],
|
||||
)
|
||||
|
||||
# Node manager server and client.
|
||||
cc_library(
|
||||
name = "node_manager_rpc_lib",
|
||||
srcs = glob([
|
||||
"src/ray/rpc/*.cc",
|
||||
]),
|
||||
hdrs = glob([
|
||||
"src/ray/rpc/*.h",
|
||||
]),
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":node_manager_grpc_lib",
|
||||
":ray_common",
|
||||
"@boost//:asio",
|
||||
"@com_github_grpc_grpc//:grpc++",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "raylet",
|
||||
srcs = ["src/ray/raylet/main.cc"],
|
||||
@@ -89,6 +114,7 @@ cc_library(
|
||||
":gcs",
|
||||
":gcs_fbs",
|
||||
":node_manager_fbs",
|
||||
":node_manager_rpc_lib",
|
||||
":object_manager",
|
||||
":ray_common",
|
||||
":ray_util",
|
||||
|
||||
@@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
|
||||
load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories")
|
||||
load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure")
|
||||
load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps")
|
||||
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
|
||||
|
||||
|
||||
def ray_deps_build_all():
|
||||
gen_java_deps()
|
||||
@@ -10,3 +12,5 @@ def ray_deps_build_all():
|
||||
boost_deps()
|
||||
prometheus_cpp_repositories()
|
||||
python_configure(name = "local_config_python")
|
||||
grpc_deps()
|
||||
|
||||
|
||||
@@ -101,3 +101,11 @@ def ray_deps_setup():
|
||||
# `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged.
|
||||
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"],
|
||||
)
|
||||
|
||||
http_archive(
|
||||
name = "com_github_grpc_grpc",
|
||||
urls = [
|
||||
"https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz",
|
||||
],
|
||||
strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49",
|
||||
)
|
||||
|
||||
+1
-1
@@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1
|
||||
|
||||
CODE=$?
|
||||
if [ $CODE != 0 ]; then
|
||||
cat $TMPFILE
|
||||
tail -n 2000 $TMPFILE
|
||||
echo "FAILED $CODE"
|
||||
kill $WATCHDOG_PID
|
||||
exit $CODE
|
||||
|
||||
@@ -16,7 +16,7 @@ else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh"
|
||||
URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh"
|
||||
wget -O install.sh $URL
|
||||
chmod +x install.sh
|
||||
./install.sh --user
|
||||
|
||||
+5
-2
@@ -94,11 +94,14 @@ define_java_module(
|
||||
":org_ray_ray_api",
|
||||
":org_ray_ray_runtime",
|
||||
"@plasma//:org_apache_arrow_arrow_plasma",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:com_sun_xml_bind_jaxb_core",
|
||||
"@maven//:com_sun_xml_bind_jaxb_impl",
|
||||
"@maven//:commons_io_commons_io",
|
||||
"@maven//:javax_xml_bind_jaxb_api",
|
||||
"@maven//:org_apache_commons_commons_lang3",
|
||||
"@maven//:org_slf4j_slf4j_api",
|
||||
"@maven//:org_testng_testng",
|
||||
"@maven//:com_google_guava_guava",
|
||||
"@maven//:commons_io_commons_io",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -32,11 +32,26 @@
|
||||
<artifactId>guava</artifactId>
|
||||
<version>27.0.1-jre</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-core</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.sun.xml.bind</groupId>
|
||||
<artifactId>jaxb-impl</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>commons-io</groupId>
|
||||
<artifactId>commons-io</artifactId>
|
||||
<version>2.5</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>javax.xml.bind</groupId>
|
||||
<artifactId>jaxb-api</artifactId>
|
||||
<version>2.3.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-lang3</artifactId>
|
||||
|
||||
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray.tests.conftest import _ray_start_cluster
|
||||
|
||||
num_tasks_submitted = [10**n for n in range(0, 6)]
|
||||
num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted]
|
||||
@@ -41,3 +42,25 @@ def test_task_submission(benchmark, num_tasks):
|
||||
warmup()
|
||||
benchmark(benchmark_task_submission, num_tasks)
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
def benchmark_task_forward(f, num_tasks):
|
||||
ray.get([f.remote() for _ in range(num_tasks)])
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
@pytest.mark.parametrize(
|
||||
"num_tasks", [10**3, 10**4],
|
||||
ids=[str(num) + "_tasks" for num in [10**3, 10**4]])
|
||||
def test_task_forward(benchmark, num_tasks):
|
||||
with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster:
|
||||
cluster.add_node(resources={"my_resource": 100})
|
||||
ray.init(redis_address=cluster.redis_address)
|
||||
|
||||
@ray.remote(resources={"my_resource": 0.001})
|
||||
def f():
|
||||
return 1
|
||||
|
||||
# Warm up
|
||||
ray.get([f.remote() for _ in range(100)])
|
||||
benchmark(benchmark_task_forward, f, num_tasks)
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ray.rpc;
|
||||
|
||||
message ForwardTaskRequest {
|
||||
// The ID of the task to be forwarded.
|
||||
bytes task_id = 1;
|
||||
// The tasks in the uncommitted lineage of the forwarded task. This
|
||||
// should include task_id.
|
||||
// TODO(hchen): Currently, `uncommitted_tasks` are represented as
|
||||
// flatbutters-serialized bytes. This is because the flatbuffers-defined Task data
|
||||
// structure is being used in many places. We should move Task and all related data
|
||||
// strucutres to protobuf.
|
||||
repeated bytes uncommitted_tasks = 2;
|
||||
}
|
||||
|
||||
message ForwardTaskReply {
|
||||
}
|
||||
|
||||
// Service for inter-node-manager communication.
|
||||
service NodeManagerService {
|
||||
// Forward a task and its uncommitted lineage to the remote node manager.
|
||||
rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply);
|
||||
}
|
||||
+87
-144
@@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
|
||||
lineage_cache_(gcs_client_->client_table().GetLocalClientId(),
|
||||
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
|
||||
config.max_lineage_size),
|
||||
remote_clients_(),
|
||||
remote_server_connections_(),
|
||||
actor_registry_() {
|
||||
actor_registry_(),
|
||||
node_manager_server_(config.node_manager_port, io_service, *this),
|
||||
client_call_manager_(io_service) {
|
||||
RAY_CHECK(heartbeat_period_.count() > 0);
|
||||
// Initialize the resource map with own cluster resource configuration.
|
||||
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
|
||||
@@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
|
||||
[this](const ObjectID &object_id) { HandleObjectMissing(object_id); }));
|
||||
|
||||
RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
|
||||
// Run the node manger rpc server.
|
||||
node_manager_server_.Run();
|
||||
}
|
||||
|
||||
ray::Status NodeManager::RegisterGcs() {
|
||||
@@ -366,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(atumanov): make remote client lookup O(1)
|
||||
if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) ==
|
||||
remote_clients_.end()) {
|
||||
remote_clients_.push_back(client_id);
|
||||
} else {
|
||||
// NodeManager connection to this client was already established.
|
||||
RAY_LOG(DEBUG) << "Received a new client connection that already exists: "
|
||||
auto entry = remote_node_manager_clients_.find(client_id);
|
||||
if (entry != remote_node_manager_clients_.end()) {
|
||||
RAY_LOG(DEBUG) << "Received notification of a new client that already exists: "
|
||||
<< client_id;
|
||||
return;
|
||||
}
|
||||
|
||||
// Establish a new NodeManager connection to this GCS client.
|
||||
auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address,
|
||||
client_data.node_manager_port);
|
||||
if (!status.ok()) {
|
||||
// This is not a fatal error for raylet, but it should not happen.
|
||||
// We need to broadcase this message.
|
||||
std::string type = "raylet_connection_error";
|
||||
std::ostringstream error_message;
|
||||
error_message << "Failed to connect to ray node " << client_id
|
||||
<< " with status: " << status.ToString()
|
||||
<< ". This may be since the node was recently removed.";
|
||||
// We use the nil DriverID to broadcast the message to all drivers.
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
DriverID::Nil(), type, error_message.str(), current_time_ms()));
|
||||
return;
|
||||
}
|
||||
// Initialize a rpc client to the new node manager.
|
||||
std::unique_ptr<rpc::NodeManagerClient> client(
|
||||
new rpc::NodeManagerClient(client_data.node_manager_address,
|
||||
client_data.node_manager_port, client_call_manager_));
|
||||
remote_node_manager_clients_.emplace(client_id, std::move(client));
|
||||
|
||||
ResourceSet resources_total(client_data.resources_total_label,
|
||||
client_data.resources_total_capacity);
|
||||
cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
|
||||
}
|
||||
|
||||
ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id,
|
||||
const std::string &client_address,
|
||||
int32_t client_port) {
|
||||
// Establish a new NodeManager connection to this GCS client.
|
||||
RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at "
|
||||
<< client_address << ":" << client_port;
|
||||
|
||||
boost::asio::ip::tcp::socket socket(io_service_);
|
||||
RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port));
|
||||
|
||||
// The client is connected, now send a connect message to remote node manager.
|
||||
auto server_conn = TcpServerConnection::Create(std::move(socket));
|
||||
|
||||
// Prepare client connection info buffer
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_));
|
||||
fbb.Finish(message);
|
||||
// Send synchronously.
|
||||
// TODO(swang): Make this a WriteMessageAsync.
|
||||
RAY_RETURN_NOT_OK(server_conn->WriteMessage(
|
||||
static_cast<int64_t>(protocol::MessageType::ConnectClient), fbb.GetSize(),
|
||||
fbb.GetBufferPointer()));
|
||||
|
||||
remote_server_connections_.emplace(client_id, std::move(server_conn));
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
|
||||
// TODO(swang): If we receive a notification for our own death, clean up and
|
||||
// exit immediately.
|
||||
@@ -440,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
|
||||
// check that it is actually removed, or log a warning otherwise, but that may
|
||||
// not be necessary.
|
||||
|
||||
// Remove the client from the list of remote clients.
|
||||
std::remove(remote_clients_.begin(), remote_clients_.end(), client_id);
|
||||
|
||||
// Remove the client from the resource map.
|
||||
cluster_resource_map_.erase(client_id);
|
||||
|
||||
// Remove the remote server connection.
|
||||
const auto connection_entry = remote_server_connections_.find(client_id);
|
||||
if (connection_entry != remote_server_connections_.end()) {
|
||||
connection_entry->second->Close();
|
||||
remote_server_connections_.erase(connection_entry);
|
||||
// Remove the node manager client.
|
||||
const auto client_entry = remote_node_manager_clients_.find(client_id);
|
||||
if (client_entry != remote_node_manager_clients_.end()) {
|
||||
remote_node_manager_clients_.erase(client_entry);
|
||||
} else {
|
||||
RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client "
|
||||
<< client_id << ".";
|
||||
@@ -1241,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client
|
||||
node_manager_client.ProcessMessages();
|
||||
}
|
||||
|
||||
void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
|
||||
int64_t message_type,
|
||||
const uint8_t *message_data) {
|
||||
const auto message_type_value = static_cast<protocol::MessageType>(message_type);
|
||||
RAY_LOG(DEBUG) << "[NodeManager] Message "
|
||||
<< protocol::EnumNameMessageType(message_type_value) << "("
|
||||
<< message_type << ") from node manager";
|
||||
switch (message_type_value) {
|
||||
case protocol::MessageType::ConnectClient: {
|
||||
auto message = flatbuffers::GetRoot<protocol::ConnectClient>(message_data);
|
||||
auto client_id = from_flatbuf<ClientID>(*message->client_id());
|
||||
node_manager_client.SetClientID(client_id);
|
||||
} break;
|
||||
case protocol::MessageType::ForwardTaskRequest: {
|
||||
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
|
||||
TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
|
||||
|
||||
Lineage uncommitted_lineage(*message);
|
||||
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
|
||||
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
|
||||
<< " on node " << gcs_client_->client_table().GetLocalClientId()
|
||||
<< " spillback=" << task.GetTaskExecutionSpec().NumForwards();
|
||||
SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
|
||||
} break;
|
||||
case protocol::MessageType::DisconnectClient: {
|
||||
// TODO(rkn): We need to do some cleanup here.
|
||||
RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. "
|
||||
<< "We need to do some cleanup here.";
|
||||
// Do not process any more messages from this node manager.
|
||||
return;
|
||||
} break;
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
|
||||
void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
|
||||
rpc::ForwardTaskReply *reply,
|
||||
rpc::RequestDoneCallback done_callback) {
|
||||
// Get the forwarded task and its uncommitted lineage from the request.
|
||||
TaskID task_id = TaskID::FromBinary(request.task_id());
|
||||
Lineage uncommitted_lineage;
|
||||
for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
|
||||
const std::string &task_message = request.uncommitted_tasks(i);
|
||||
const Task task(*flatbuffers::GetRoot<protocol::Task>(
|
||||
reinterpret_cast<const uint8_t *>(task_message.data())));
|
||||
RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED));
|
||||
}
|
||||
node_manager_client.ProcessMessages();
|
||||
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
|
||||
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
|
||||
<< " on node " << gcs_client_->client_table().GetLocalClientId()
|
||||
<< " spillback=" << task.GetTaskExecutionSpec().NumForwards();
|
||||
SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
|
||||
done_callback(Status::OK());
|
||||
}
|
||||
|
||||
void NodeManager::ProcessSetResourceRequest(
|
||||
@@ -2253,6 +2192,16 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
|
||||
void NodeManager::ForwardTask(
|
||||
const Task &task, const ClientID &node_id,
|
||||
const std::function<void(const ray::Status &, const Task &)> &on_error) {
|
||||
// Lookup node manager client for this node_id and use it to send the request.
|
||||
auto client_entry = remote_node_manager_clients_.find(node_id);
|
||||
if (client_entry == remote_node_manager_clients_.end()) {
|
||||
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
|
||||
RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id;
|
||||
on_error(ray::Status::IOError("Node manager client not found"), task);
|
||||
return;
|
||||
}
|
||||
auto &client = client_entry->second;
|
||||
|
||||
const auto &spec = task.GetTaskSpecification();
|
||||
auto task_id = spec.TaskId();
|
||||
|
||||
@@ -2272,68 +2221,61 @@ void NodeManager::ForwardTask(
|
||||
// Increment forward count for the forwarded task.
|
||||
lineage_cache_entry_task.IncrementNumForwards();
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id);
|
||||
fbb.Finish(request);
|
||||
|
||||
RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from "
|
||||
<< gcs_client_->client_table().GetLocalClientId() << " to " << node_id
|
||||
<< " spillback="
|
||||
<< lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards();
|
||||
|
||||
// Lookup remote server connection for this node_id and use it to send the request.
|
||||
auto it = remote_server_connections_.find(node_id);
|
||||
if (it == remote_server_connections_.end()) {
|
||||
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
|
||||
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
|
||||
on_error(ray::Status::IOError("NodeManager connection not found"), task);
|
||||
return;
|
||||
// Prepare the request message.
|
||||
rpc::ForwardTaskRequest request;
|
||||
request.set_task_id(task_id.Binary());
|
||||
for (auto &entry : uncommitted_lineage.GetEntries()) {
|
||||
request.add_uncommitted_tasks(entry.second.TaskData().Serialize());
|
||||
}
|
||||
auto &server_conn = it->second;
|
||||
|
||||
// Move the FORWARDING task to the SWAP queue so that we remember that we
|
||||
// have it queued locally. Once the ForwardTaskRequest has been sent, the
|
||||
// task will get re-queued, depending on whether the message succeeded or
|
||||
// not.
|
||||
local_queues_.QueueTasks({task}, TaskState::SWAP);
|
||||
server_conn->WriteMessageAsync(
|
||||
static_cast<int64_t>(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(),
|
||||
fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) {
|
||||
// Remove the FORWARDING task from the SWAP queue.
|
||||
TaskState state;
|
||||
const auto task = local_queues_.RemoveTask(task_id, &state);
|
||||
RAY_CHECK(state == TaskState::SWAP);
|
||||
client->ForwardTask(request, [this, on_error, task_id, node_id](
|
||||
Status status, const rpc::ForwardTaskReply &reply) {
|
||||
// Remove the FORWARDING task from the SWAP queue.
|
||||
TaskState state;
|
||||
const auto task = local_queues_.RemoveTask(task_id, &state);
|
||||
RAY_CHECK(state == TaskState::SWAP);
|
||||
|
||||
if (status.ok()) {
|
||||
const auto &spec = task.GetTaskSpecification();
|
||||
// Mark as forwarded so that the task and its lineage are not
|
||||
// re-forwarded in the future to the receiving node.
|
||||
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
|
||||
if (status.ok()) {
|
||||
const auto &spec = task.GetTaskSpecification();
|
||||
// Mark as forwarded so that the task and its lineage are not
|
||||
// re-forwarded in the future to the receiving node.
|
||||
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
|
||||
|
||||
// Notify the task dependency manager that we are no longer responsible
|
||||
// for executing this task.
|
||||
task_dependency_manager_.TaskCanceled(task_id);
|
||||
// Preemptively push any local arguments to the receiving node. For now, we
|
||||
// only do this with actor tasks, since actor tasks must be executed by a
|
||||
// specific process and therefore have affinity to the receiving node.
|
||||
if (spec.IsActorTask()) {
|
||||
// Iterate through the object's arguments. NOTE(swang): We do not include
|
||||
// the execution dependencies here since those cannot be transferred
|
||||
// between nodes.
|
||||
for (int i = 0; i < spec.NumArgs(); ++i) {
|
||||
int count = spec.ArgIdCount(i);
|
||||
for (int j = 0; j < count; j++) {
|
||||
ObjectID argument_id = spec.ArgId(i, j);
|
||||
// If the argument is local, then push it to the receiving node.
|
||||
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
|
||||
object_manager_.Push(argument_id, node_id);
|
||||
}
|
||||
}
|
||||
// Notify the task dependency manager that we are no longer responsible
|
||||
// for executing this task.
|
||||
task_dependency_manager_.TaskCanceled(task_id);
|
||||
// Preemptively push any local arguments to the receiving node. For now, we
|
||||
// only do this with actor tasks, since actor tasks must be executed by a
|
||||
// specific process and therefore have affinity to the receiving node.
|
||||
if (spec.IsActorTask()) {
|
||||
// Iterate through the object's arguments. NOTE(swang): We do not include
|
||||
// the execution dependencies here since those cannot be transferred
|
||||
// between nodes.
|
||||
for (int i = 0; i < spec.NumArgs(); ++i) {
|
||||
int count = spec.ArgIdCount(i);
|
||||
for (int j = 0; j < count; j++) {
|
||||
ObjectID argument_id = spec.ArgId(i, j);
|
||||
// If the argument is local, then push it to the receiving node.
|
||||
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
|
||||
object_manager_.Push(argument_id, node_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
on_error(status, task);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
on_error(status, task);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void NodeManager::DumpDebugState() const {
|
||||
@@ -2368,10 +2310,11 @@ std::string NodeManager::DebugString() const {
|
||||
result << "\n- num dead actors: " << statistical_data.dead_actors;
|
||||
result << "\n- max num handles: " << statistical_data.max_num_handles;
|
||||
|
||||
result << "\nRemoteConnections:";
|
||||
for (auto &pair : remote_server_connections_) {
|
||||
result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString();
|
||||
result << "\nRemote node manager clients: ";
|
||||
for (const auto &entry : remote_node_manager_clients_) {
|
||||
result << "\n" << entry.first;
|
||||
}
|
||||
|
||||
result << "\nDebugString() time ms: " << (current_time_ms() - now_ms);
|
||||
return result.str();
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
#include <boost/asio/steady_timer.hpp>
|
||||
|
||||
// clang-format off
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/rpc/node_manager_server.h"
|
||||
#include "ray/rpc/node_manager_client.h"
|
||||
#include "ray/raylet/task.h"
|
||||
#include "ray/object_manager/object_manager.h"
|
||||
#include "ray/common/client_connection.h"
|
||||
@@ -52,7 +55,7 @@ struct NodeManagerConfig {
|
||||
std::string session_dir;
|
||||
};
|
||||
|
||||
class NodeManager {
|
||||
class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
public:
|
||||
/// Create a node manager.
|
||||
///
|
||||
@@ -86,15 +89,6 @@ class NodeManager {
|
||||
/// \return Void.
|
||||
void ProcessNewNodeManager(TcpClientConnection &node_manager_client);
|
||||
|
||||
/// Handle a message from a remote node manager.
|
||||
///
|
||||
/// \param node_manager_client The connection to the remote node manager.
|
||||
/// \param message_type The type of the message.
|
||||
/// \param message The message contents.
|
||||
/// \return Void.
|
||||
void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
|
||||
int64_t message_type, const uint8_t *message);
|
||||
|
||||
/// Subscribe to the relevant GCS tables and set up handlers.
|
||||
///
|
||||
/// \return Status indicating whether this was done successfully or not.
|
||||
@@ -108,6 +102,9 @@ class NodeManager {
|
||||
/// Record metrics.
|
||||
void RecordMetrics() const;
|
||||
|
||||
/// Get the port of the node manager rpc server.
|
||||
int GetServerPort() const { return node_manager_server_.GetPort(); }
|
||||
|
||||
private:
|
||||
/// Methods for handling clients.
|
||||
|
||||
@@ -450,15 +447,10 @@ class NodeManager {
|
||||
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
|
||||
bool intentional_disconnect);
|
||||
|
||||
/// connect to a remote node manager.
|
||||
///
|
||||
/// \param client_id The client ID for the remote node manager.
|
||||
/// \param client_address The IP address for the remote node manager.
|
||||
/// \param client_port The listening port for the remote node manager.
|
||||
/// \return True if the connect succeeds.
|
||||
ray::Status ConnectRemoteNodeManager(const ClientID &client_id,
|
||||
const std::string &client_address,
|
||||
int32_t client_port);
|
||||
/// Handle a `ForwardTask` request.
|
||||
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
|
||||
rpc::ForwardTaskReply *reply,
|
||||
rpc::RequestDoneCallback done_callback) override;
|
||||
|
||||
// GCS client ID for this node.
|
||||
ClientID client_id_;
|
||||
@@ -505,9 +497,6 @@ class NodeManager {
|
||||
TaskDependencyManager task_dependency_manager_;
|
||||
/// The lineage cache for the GCS object and task tables.
|
||||
LineageCache lineage_cache_;
|
||||
std::vector<ClientID> remote_clients_;
|
||||
std::unordered_map<ClientID, std::shared_ptr<TcpServerConnection>>
|
||||
remote_server_connections_;
|
||||
/// A mapping from actor ID to registration information about that actor
|
||||
/// (including which node manager owns it).
|
||||
std::unordered_map<ActorID, ActorRegistration> actor_registry_;
|
||||
@@ -515,6 +504,16 @@ class NodeManager {
|
||||
/// This map stores actor ID to the ID of the checkpoint that will be used to
|
||||
/// restore the actor.
|
||||
std::unordered_map<ActorID, ActorCheckpointID> checkpoint_id_to_restore_;
|
||||
|
||||
/// The RPC server.
|
||||
rpc::NodeManagerServer node_manager_server_;
|
||||
|
||||
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
|
||||
rpc::ClientCallManager client_call_manager_;
|
||||
|
||||
/// Map from node ids to clients of the remote node managers.
|
||||
std::unordered_map<ClientID, std::unique_ptr<rpc::NodeManagerClient>>
|
||||
remote_node_manager_clients_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_
|
||||
main_service,
|
||||
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(),
|
||||
object_manager_config.object_manager_port)),
|
||||
object_manager_socket_(main_service),
|
||||
node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint(
|
||||
boost::asio::ip::tcp::v4(),
|
||||
node_manager_config.node_manager_port)),
|
||||
node_manager_socket_(main_service) {
|
||||
object_manager_socket_(main_service) {
|
||||
// Start listening for clients.
|
||||
DoAccept();
|
||||
DoAcceptObjectManager();
|
||||
DoAcceptNodeManager();
|
||||
|
||||
RAY_CHECK_OK(RegisterGcs(
|
||||
node_ip_address, socket_name_, object_manager_config.store_socket_name,
|
||||
@@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
|
||||
client_info.raylet_socket_name = raylet_socket_name;
|
||||
client_info.object_store_socket_name = object_store_socket_name;
|
||||
client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port();
|
||||
client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port();
|
||||
client_info.node_manager_port = node_manager_.GetServerPort();
|
||||
// Add resource information.
|
||||
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) {
|
||||
client_info.resources_total_label.push_back(resource_pair.first);
|
||||
@@ -120,33 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Raylet::DoAcceptNodeManager() {
|
||||
node_manager_acceptor_.async_accept(node_manager_socket_,
|
||||
boost::bind(&Raylet::HandleAcceptNodeManager, this,
|
||||
boost::asio::placeholders::error));
|
||||
}
|
||||
|
||||
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
|
||||
if (!error) {
|
||||
ClientHandler<boost::asio::ip::tcp> client_handler =
|
||||
[this](TcpClientConnection &client) {
|
||||
node_manager_.ProcessNewNodeManager(client);
|
||||
};
|
||||
MessageHandler<boost::asio::ip::tcp> message_handler =
|
||||
[this](std::shared_ptr<TcpClientConnection> client, int64_t message_type,
|
||||
const uint8_t *message) {
|
||||
node_manager_.ProcessNodeManagerMessage(*client, message_type, message);
|
||||
};
|
||||
// Accept a new TCP client and dispatch it to the node manager.
|
||||
auto new_connection = TcpClientConnection::Create(
|
||||
client_handler, message_handler, std::move(node_manager_socket_), "node manager",
|
||||
node_manager_message_enum,
|
||||
static_cast<int64_t>(protocol::MessageType::DisconnectClient));
|
||||
}
|
||||
// We're ready to accept another client.
|
||||
DoAcceptNodeManager();
|
||||
}
|
||||
|
||||
void Raylet::DoAcceptObjectManager() {
|
||||
object_manager_acceptor_.async_accept(
|
||||
object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this,
|
||||
|
||||
@@ -63,8 +63,6 @@ class Raylet {
|
||||
void DoAcceptObjectManager();
|
||||
/// Handle an accepted tcp client connection.
|
||||
void HandleAcceptObjectManager(const boost::system::error_code &error);
|
||||
void DoAcceptNodeManager();
|
||||
void HandleAcceptNodeManager(const boost::system::error_code &error);
|
||||
|
||||
friend class TestObjectManagerIntegration;
|
||||
|
||||
@@ -88,10 +86,6 @@ class Raylet {
|
||||
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
|
||||
/// The socket to listen on for new object manager tcp clients.
|
||||
boost::asio::ip::tcp::socket object_manager_socket_;
|
||||
/// An acceptor for new tcp clients.
|
||||
boost::asio::ip::tcp::acceptor node_manager_acceptor_;
|
||||
/// The socket to listen on for new tcp clients.
|
||||
boost::asio::ip::tcp::socket node_manager_socket_;
|
||||
};
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) {
|
||||
ComputeDependencies();
|
||||
}
|
||||
|
||||
const std::string Task::Serialize() const {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
fbb.Finish(ToFlatbuffer(fbb));
|
||||
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
|
||||
}
|
||||
|
||||
std::string SerializeTaskAsString(const std::vector<ObjectID> *dependencies,
|
||||
const TaskSpecification *task_spec) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
std::vector<ObjectID> execution_dependencies(*dependencies);
|
||||
TaskExecutionSpecification execution_spec(std::move(execution_dependencies));
|
||||
Task task(execution_spec, *task_spec);
|
||||
fbb.Finish(task.ToFlatbuffer(fbb));
|
||||
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
|
||||
return task.Serialize();
|
||||
}
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
@@ -84,6 +84,9 @@ class Task {
|
||||
/// \param task Task structure with updated dynamic information.
|
||||
void CopyTaskExecutionSpec(const Task &task);
|
||||
|
||||
/// Serialize this task as a string.
|
||||
const std::string Serialize() const;
|
||||
|
||||
private:
|
||||
void ComputeDependencies();
|
||||
|
||||
|
||||
@@ -0,0 +1,169 @@
|
||||
#ifndef RAY_RPC_CLIENT_CALL_H
|
||||
#define RAY_RPC_CLIENT_CALL_H
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Represents an outgoing gRPC request.
|
||||
///
|
||||
/// The lifecycle of a `ClientCall` is as follows.
|
||||
///
|
||||
/// When a client submits a new gRPC request, a new `ClientCall` object will be created
|
||||
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
|
||||
/// `CompletionQueue`.
|
||||
///
|
||||
/// When the reply is received, `ClientCallMangager` will get the address of this object
|
||||
/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then
|
||||
/// delete this object.
|
||||
///
|
||||
/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use
|
||||
/// template. This allows the users (e.g., `ClientCallMangager`) not having to use
|
||||
/// template as well.
|
||||
class ClientCall {
|
||||
public:
|
||||
/// The callback to be called by `ClientCallManager` when the reply of this request is
|
||||
/// received.
|
||||
virtual void OnReplyReceived() = 0;
|
||||
};
|
||||
|
||||
class ClientCallManager;
|
||||
|
||||
/// Reprents the client callback function of a particular rpc method.
|
||||
///
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class Reply>
|
||||
using ClientCallback = std::function<void(const Status &status, const Reply &reply)>;
|
||||
|
||||
/// Implementaion of the `ClientCall`. It represents a `ClientCall` for a particular
|
||||
/// RPC method.
|
||||
///
|
||||
/// \tparam Reply Type of the Reply message.
|
||||
template <class Reply>
|
||||
class ClientCallImpl : public ClientCall {
|
||||
public:
|
||||
void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); }
|
||||
|
||||
private:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] callback The callback function to handle the reply.
|
||||
ClientCallImpl(const ClientCallback<Reply> &callback) : callback_(callback) {}
|
||||
|
||||
/// The reply message.
|
||||
Reply reply_;
|
||||
|
||||
/// The callback function to handle the reply.
|
||||
ClientCallback<Reply> callback_;
|
||||
|
||||
/// The response reader.
|
||||
std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader_;
|
||||
|
||||
/// gRPC status of this request.
|
||||
grpc::Status status_;
|
||||
|
||||
/// Context for the client. It could be used to convey extra information to
|
||||
/// the server and/or tweak certain RPC behaviors.
|
||||
grpc::ClientContext context_;
|
||||
|
||||
friend class ClientCallManager;
|
||||
};
|
||||
|
||||
/// Peprents the generic signature of a `FooService::Stub::PrepareAsyncBar`
|
||||
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
|
||||
///
|
||||
/// \tparam GrpcService Type of the gRPC-generated service class.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class GrpcService, class Request, class Reply>
|
||||
using PrepareAsyncFunction = std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> (
|
||||
GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request,
|
||||
grpc::CompletionQueue *cq);
|
||||
|
||||
/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of
|
||||
/// `ClientCall` objects.
|
||||
///
|
||||
/// It maintains a thread that keeps polling events from `CompletionQueue`, and post
|
||||
/// the callback function to the main event loop when a reply is received.
|
||||
///
|
||||
/// Mutiple clients can share one `ClientCallManager`.
|
||||
class ClientCallManager {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] main_service The main event loop, to which the callback functions will be
|
||||
/// posted.
|
||||
ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) {
|
||||
// Start the polling thread.
|
||||
std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
|
||||
polling_thread.detach();
|
||||
}
|
||||
|
||||
~ClientCallManager() { cq_.Shutdown(); }
|
||||
|
||||
/// Create a new `ClientCall` and send request.
|
||||
///
|
||||
/// \param[in] stub The gRPC-generated stub.
|
||||
/// \param[in] prepare_async_function Pointer to the gRPC-generated
|
||||
/// `FooService::Stub::PrepareAsyncBar` function.
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
///
|
||||
/// \tparam GrpcService Type of the gRPC-generated service class.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class GrpcService, class Request, class Reply>
|
||||
ClientCall *CreateCall(
|
||||
typename GrpcService::Stub &stub,
|
||||
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
|
||||
const Request &request, const ClientCallback<Reply> &callback) {
|
||||
// Create a new `ClientCall` object. This object will eventuall be deleted in the
|
||||
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
|
||||
auto call = new ClientCallImpl<Reply>(callback);
|
||||
// Send request.
|
||||
call->response_reader_ =
|
||||
(stub.*prepare_async_function)(&call->context_, request, &cq_);
|
||||
call->response_reader_->StartCall();
|
||||
call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call);
|
||||
return call;
|
||||
}
|
||||
|
||||
private:
|
||||
/// This function runs in a background thread. It keeps polling events from the
|
||||
/// `CompletionQueue`, and dispaches the event to the callbacks via the `ClientCall`
|
||||
/// objects.
|
||||
void PollEventsFromCompletionQueue() {
|
||||
void *got_tag;
|
||||
bool ok = false;
|
||||
// Keep reading events from the `CompletionQueue` until it's shutdown.
|
||||
while (cq_.Next(&got_tag, &ok)) {
|
||||
ClientCall *call = reinterpret_cast<ClientCall *>(got_tag);
|
||||
if (ok) {
|
||||
// Post the callback to the main event loop.
|
||||
main_service_.post([call]() {
|
||||
call->OnReplyReceived();
|
||||
// The call is finished, we can delete the `ClientCall` object now.
|
||||
delete call;
|
||||
});
|
||||
} else {
|
||||
delete call;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The main event loop, to which the callback functions will be posted.
|
||||
boost::asio::io_service &main_service_;
|
||||
|
||||
/// The gRPC `CompletionQueue` object used to poll events.
|
||||
grpc::CompletionQueue cq_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,70 @@
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
void GrpcServer::Run() {
|
||||
std::string server_address("0.0.0.0:" + std::to_string(port_));
|
||||
|
||||
grpc::ServerBuilder builder;
|
||||
// TODO(hchen): Add options for authentication.
|
||||
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
|
||||
// Allow subclasses to register concrete services.
|
||||
RegisterServices(builder);
|
||||
// Get hold of the completion queue used for the asynchronous communication
|
||||
// with the gRPC runtime.
|
||||
cq_ = builder.AddCompletionQueue();
|
||||
// Build and start server.
|
||||
server_ = builder.BuildAndStart();
|
||||
RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << ".";
|
||||
|
||||
// Allow subclasses to initialize the server call factories.
|
||||
InitServerCallFactories(&server_call_factories_and_concurrencies_);
|
||||
for (auto &entry : server_call_factories_and_concurrencies_) {
|
||||
for (int i = 0; i < entry.second; i++) {
|
||||
// Create and request calls from the factory.
|
||||
entry.first->CreateCall();
|
||||
}
|
||||
}
|
||||
// Start a thread that polls incoming requests.
|
||||
std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this);
|
||||
polling_thread.detach();
|
||||
}
|
||||
|
||||
void GrpcServer::PollEventsFromCompletionQueue() {
|
||||
void *tag;
|
||||
bool ok;
|
||||
// Keep reading events from the `CompletionQueue` until it's shutdown.
|
||||
while (cq_->Next(&tag, &ok)) {
|
||||
ServerCall *server_call = static_cast<ServerCall *>(tag);
|
||||
// `ok == false` indicates that the server has been shut down.
|
||||
// We should delete the call object in this case.
|
||||
bool delete_call = !ok;
|
||||
if (ok) {
|
||||
switch (server_call->GetState()) {
|
||||
case ServerCallState::PENDING:
|
||||
// We've received a new incoming request. Now this call object is used to
|
||||
// track this request. So we need to create another call to handle next
|
||||
// incoming request.
|
||||
server_call->GetFactory().CreateCall();
|
||||
server_call->SetState(ServerCallState::PROCESSING);
|
||||
main_service_.post([server_call] { server_call->HandleRequest(); });
|
||||
break;
|
||||
case ServerCallState::SENDING_REPLY:
|
||||
// The reply has been sent, this call can be deleted now.
|
||||
// This event is triggered by `ServerCallImpl::SendReply`.
|
||||
delete_call = true;
|
||||
break;
|
||||
default:
|
||||
RAY_LOG(FATAL) << "Shouldn't reach here.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (delete_call) {
|
||||
delete server_call;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
@@ -0,0 +1,92 @@
|
||||
#ifndef RAY_RPC_GRPC_SERVER_H
|
||||
#define RAY_RPC_GRPC_SERVER_H
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Base class that represents an abstract gRPC server.
|
||||
///
|
||||
/// A `GrpcServer` listens on a specific port. It owns
|
||||
/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
|
||||
/// 2) and a thread that polls events from the `ServerCompletionQueue`.
|
||||
///
|
||||
/// Subclasses can register one or multiple services to a `GrpcServer`, see
|
||||
/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide
|
||||
/// which kinds of requests this server should accept.
|
||||
class GrpcServer {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] name Name of this server, used for logging and debugging purpose.
|
||||
/// \param[in] port The port to bind this server to. If it's 0, a random available port
|
||||
/// will be chosen.
|
||||
/// \param[in] main_service The main event loop, to which service handler functions
|
||||
/// will be posted.
|
||||
GrpcServer(const std::string &name, const uint32_t port,
|
||||
boost::asio::io_service &main_service)
|
||||
: name_(name), port_(port), main_service_(main_service) {}
|
||||
|
||||
/// Destruct this gRPC server.
|
||||
~GrpcServer() {
|
||||
server_->Shutdown();
|
||||
cq_->Shutdown();
|
||||
}
|
||||
|
||||
/// Initialize and run this server.
|
||||
void Run();
|
||||
|
||||
/// Get the port of this gRPC server.
|
||||
int GetPort() const { return port_; }
|
||||
|
||||
protected:
|
||||
/// Subclasses should implement this method and register one or multiple gRPC services
|
||||
/// to the given `ServerBuilder`.
|
||||
///
|
||||
/// \param[in] builder The `ServerBuilder` instance to register services to.
|
||||
virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;
|
||||
|
||||
/// Subclasses should implement this method to initialize the `ServerCallFactory`
|
||||
/// instances, as well as specify maximum number of concurrent requests that gRPC
|
||||
/// server can "accept" (not "handle"). Each factory will be used to create
|
||||
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
|
||||
/// handle an incoming request.
|
||||
///
|
||||
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
|
||||
/// and the maximum number of concurrent requests that gRPC server can accept.
|
||||
virtual void InitServerCallFactories(
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) = 0;
|
||||
|
||||
/// This function runs in a background thread. It keeps polling events from the
|
||||
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
|
||||
/// via the `ServerCall` objects.
|
||||
void PollEventsFromCompletionQueue();
|
||||
|
||||
/// The main event loop, to which the service handler functions will be posted.
|
||||
boost::asio::io_service &main_service_;
|
||||
/// Name of this server, used for logging and debugging purpose.
|
||||
const std::string name_;
|
||||
/// Port of this server.
|
||||
int port_;
|
||||
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
|
||||
/// gRPC server can accept.
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
server_call_factories_and_concurrencies_;
|
||||
/// The `ServerCompletionQueue` object used for polling events.
|
||||
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
|
||||
/// The `Server` object.
|
||||
std::unique_ptr<grpc::Server> server_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,56 @@
|
||||
#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H
|
||||
#define RAY_RPC_NODE_MANAGER_CLIENT_H
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/client_call.h"
|
||||
#include "ray/util/logging.h"
|
||||
#include "src/ray/protobuf/node_manager.grpc.pb.h"
|
||||
#include "src/ray/protobuf/node_manager.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Client used for communicating with a remote node manager server.
|
||||
class NodeManagerClient {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] address Address of the node manager server.
|
||||
/// \param[in] port Port of the node manager server.
|
||||
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
|
||||
NodeManagerClient(const std::string &address, const int port,
|
||||
ClientCallManager &client_call_manager)
|
||||
: client_call_manager_(client_call_manager) {
|
||||
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
|
||||
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
|
||||
stub_ = NodeManagerService::NewStub(channel);
|
||||
};
|
||||
|
||||
/// Forward a task and its uncommitted lineage.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[in] callback The callback function that handles reply.
|
||||
void ForwardTask(const ForwardTaskRequest &request,
|
||||
const ClientCallback<ForwardTaskReply> &callback) {
|
||||
client_call_manager_
|
||||
.CreateCall<NodeManagerService, ForwardTaskRequest, ForwardTaskReply>(
|
||||
*stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request,
|
||||
callback);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The gRPC-generated stub.
|
||||
std::unique_ptr<NodeManagerService::Stub> stub_;
|
||||
|
||||
/// The `ClientCallManager` used for managing requests.
|
||||
ClientCallManager &client_call_manager_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif // RAY_RPC_NODE_MANAGER_CLIENT_H
|
||||
@@ -0,0 +1,71 @@
|
||||
#ifndef RAY_RPC_NODE_MANAGER_SERVER_H
|
||||
#define RAY_RPC_NODE_MANAGER_SERVER_H
|
||||
|
||||
#include "ray/rpc/grpc_server.h"
|
||||
#include "ray/rpc/server_call.h"
|
||||
|
||||
#include "src/ray/protobuf/node_manager.grpc.pb.h"
|
||||
#include "src/ray/protobuf/node_manager.pb.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
|
||||
class NodeManagerServiceHandler {
|
||||
public:
|
||||
/// Handle a `ForwardTask` request.
|
||||
/// The implementation can handle this request asynchronously. When hanling is done, the
|
||||
/// `done_callback` should be called.
|
||||
///
|
||||
/// \param[in] request The request message.
|
||||
/// \param[out] reply The reply message.
|
||||
/// \param[in] done_callback The callback to be called when the request is done.
|
||||
virtual void HandleForwardTask(const ForwardTaskRequest &request,
|
||||
ForwardTaskReply *reply,
|
||||
RequestDoneCallback done_callback) = 0;
|
||||
};
|
||||
|
||||
/// The `GrpcServer` for `NodeManagerService`.
|
||||
class NodeManagerServer : public GrpcServer {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] port See super class.
|
||||
/// \param[in] main_service See super class.
|
||||
/// \param[in] handler The service handler that actually handle the requests.
|
||||
NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
|
||||
NodeManagerServiceHandler &service_handler)
|
||||
: GrpcServer("NodeManager", port, main_service),
|
||||
service_handler_(service_handler){};
|
||||
|
||||
void RegisterServices(grpc::ServerBuilder &builder) override {
|
||||
/// Register `NodeManagerService`.
|
||||
builder.RegisterService(&service_);
|
||||
}
|
||||
|
||||
void InitServerCallFactories(
|
||||
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
|
||||
*server_call_factories_and_concurrencies) override {
|
||||
// Initialize the factory for `ForwardTask` requests.
|
||||
std::unique_ptr<ServerCallFactory> forward_task_call_factory(
|
||||
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
|
||||
ForwardTaskRequest, ForwardTaskReply>(
|
||||
service_, &NodeManagerService::AsyncService::RequestForwardTask,
|
||||
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
|
||||
|
||||
// Set `ForwardTask`'s accept concurrency to 100.
|
||||
server_call_factories_and_concurrencies->emplace_back(
|
||||
std::move(forward_task_call_factory), 100);
|
||||
}
|
||||
|
||||
private:
|
||||
/// The grpc async service object.
|
||||
NodeManagerService::AsyncService service_;
|
||||
/// The service handler that actually handle the requests.
|
||||
NodeManagerServiceHandler &service_handler_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,233 @@
|
||||
#ifndef RAY_RPC_SERVER_CALL_H
|
||||
#define RAY_RPC_SERVER_CALL_H
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
#include "ray/rpc/util.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Represents the callback function to be called when a `ServiceHandler` finishes
|
||||
/// handling a request.
|
||||
using RequestDoneCallback = std::function<void(Status)>;
|
||||
|
||||
/// Represents state of a `ServerCall`.
|
||||
enum class ServerCallState {
|
||||
/// The call is created and waiting for an incoming request.
|
||||
PENDING,
|
||||
/// Request is received and being processed.
|
||||
PROCESSING,
|
||||
/// Request processing is done, and reply is being sent to client.
|
||||
SENDING_REPLY
|
||||
};
|
||||
|
||||
class ServerCallFactory;
|
||||
|
||||
/// Reprensents an incoming request of a gRPC server.
|
||||
///
|
||||
/// The lifecycle and state transition of a `ServerCall` is as follows:
|
||||
///
|
||||
/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED]
|
||||
///
|
||||
/// (1) The `GrpcServer` creates a `ServerCall` and use it as the tag to accept requests
|
||||
/// gRPC `CompletionQueue`. Now the state is `PENDING`.
|
||||
/// (2) When a request is received, an event will be gotten from the `CompletionQueue`.
|
||||
/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call
|
||||
/// `ServerCall::HandleRequest`.
|
||||
/// (3) When the `ServiceHandler` finishes handling the request, `ServerCallImpl::Finish`
|
||||
/// will be called, and the state becomes `SENDING_REPLY`.
|
||||
/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`.
|
||||
/// `GrpcServer` will then delete this call.
|
||||
///
|
||||
/// NOTE(hchen): Compared to `ServerCallImpl`, this abstract interface doesn't use
|
||||
/// template. This allows the users (e.g., `GrpcServer`) not having to use
|
||||
/// template as well.
|
||||
class ServerCall {
|
||||
public:
|
||||
/// Get the state of this `ServerCall`.
|
||||
virtual ServerCallState GetState() const = 0;
|
||||
|
||||
/// Set state of this `ServerCall`.
|
||||
virtual void SetState(const ServerCallState &new_state) = 0;
|
||||
|
||||
/// Handle the requst. This is the callback function to be called by
|
||||
/// `GrpcServer` when the request is received.
|
||||
virtual void HandleRequest() = 0;
|
||||
|
||||
/// Get the factory that created this `ServerCall`.
|
||||
virtual const ServerCallFactory &GetFactory() const = 0;
|
||||
};
|
||||
|
||||
/// The factory that creates a particular kind of `ServerCall` objects.
|
||||
class ServerCallFactory {
|
||||
public:
|
||||
/// Create a new `ServerCall` and request gRPC runtime to start accepting the
|
||||
/// corresonding type of requests.
|
||||
///
|
||||
/// \return Pointer to the `ServerCall` object.
|
||||
virtual ServerCall *CreateCall() const = 0;
|
||||
};
|
||||
|
||||
/// Represents the generic signature of a `FooServiceHandler::HandleBar()`
|
||||
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
|
||||
///
|
||||
/// \tparam ServiceHandler Type of the handler that handles the request.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class ServiceHandler, class Request, class Reply>
|
||||
using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *,
|
||||
RequestDoneCallback);
|
||||
|
||||
/// Implementation of `ServerCall`. It represents `ServerCall` for a particular
|
||||
/// RPC method.
|
||||
///
|
||||
/// \tparam ServiceHandler Type of the handler that handles the request.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class ServiceHandler, class Request, class Reply>
|
||||
class ServerCallImpl : public ServerCall {
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] factory The factory which created this call.
|
||||
/// \param[in] service_handler The service handler that handles the request.
|
||||
/// \param[in] handle_request_function Pointer to the service handler function.
|
||||
ServerCallImpl(
|
||||
const ServerCallFactory &factory, ServiceHandler &service_handler,
|
||||
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function)
|
||||
: state_(ServerCallState::PENDING),
|
||||
factory_(factory),
|
||||
service_handler_(service_handler),
|
||||
handle_request_function_(handle_request_function),
|
||||
response_writer_(&context_) {}
|
||||
|
||||
ServerCallState GetState() const override { return state_; }
|
||||
|
||||
void SetState(const ServerCallState &new_state) override { state_ = new_state; }
|
||||
|
||||
void HandleRequest() override {
|
||||
state_ = ServerCallState::PROCESSING;
|
||||
(service_handler_.*handle_request_function_)(request_, &reply_,
|
||||
[this](Status status) {
|
||||
// When the handler is done with the
|
||||
// request, tell gRPC to finish this
|
||||
// request.
|
||||
SendReply(status);
|
||||
});
|
||||
}
|
||||
|
||||
const ServerCallFactory &GetFactory() const override { return factory_; }
|
||||
|
||||
private:
|
||||
/// Tell gRPC to finish this request.
|
||||
void SendReply(Status status) {
|
||||
state_ = ServerCallState::SENDING_REPLY;
|
||||
response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
|
||||
}
|
||||
|
||||
/// State of this call.
|
||||
ServerCallState state_;
|
||||
|
||||
/// The factory which created this call.
|
||||
const ServerCallFactory &factory_;
|
||||
|
||||
/// The service handler that handles the request.
|
||||
ServiceHandler &service_handler_;
|
||||
|
||||
/// Pointer to the service handler function.
|
||||
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function_;
|
||||
|
||||
/// Context for the request, allowing to tweak aspects of it such as the use
|
||||
/// of compression, authentication, as well as to send metadata back to the client.
|
||||
grpc::ServerContext context_;
|
||||
|
||||
/// The reponse writer.
|
||||
grpc::ServerAsyncResponseWriter<Reply> response_writer_;
|
||||
|
||||
/// The request message.
|
||||
Request request_;
|
||||
|
||||
/// The reply message.
|
||||
Reply reply_;
|
||||
|
||||
template <class T1, class T2, class T3, class T4>
|
||||
friend class ServerCallFactoryImpl;
|
||||
};
|
||||
|
||||
/// Represents the generic signature of a `FooService::AsyncService::RequestBar()`
|
||||
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
|
||||
/// \tparam GrpcService Type of the gRPC-generated service class.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class GrpcService, class Request, class Reply>
|
||||
using RequestCallFunction = void (GrpcService::AsyncService::*)(
|
||||
grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter<Reply> *,
|
||||
grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
|
||||
|
||||
/// Implementation of `ServerCallFactory`
|
||||
///
|
||||
/// \tparam GrpcService Type of the gRPC-generated service class.
|
||||
/// \tparam ServiceHandler Type of the handler that handles the request.
|
||||
/// \tparam Request Type of the request message.
|
||||
/// \tparam Reply Type of the reply message.
|
||||
template <class GrpcService, class ServiceHandler, class Request, class Reply>
|
||||
class ServerCallFactoryImpl : public ServerCallFactory {
|
||||
using AsyncService = typename GrpcService::AsyncService;
|
||||
|
||||
public:
|
||||
/// Constructor.
|
||||
///
|
||||
/// \param[in] service The gRPC-generated `AsyncService`.
|
||||
/// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod`
|
||||
// function.
|
||||
/// \param[in] service_handler The service handler that handles the request.
|
||||
/// \param[in] handle_request_function Pointer to the service handler function.
|
||||
/// \param[in] cq The `CompletionQueue`.
|
||||
ServerCallFactoryImpl(
|
||||
AsyncService &service,
|
||||
RequestCallFunction<GrpcService, Request, Reply> request_call_function,
|
||||
ServiceHandler &service_handler,
|
||||
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
|
||||
: service_(service),
|
||||
request_call_function_(request_call_function),
|
||||
service_handler_(service_handler),
|
||||
handle_request_function_(handle_request_function),
|
||||
cq_(cq) {}
|
||||
|
||||
ServerCall *CreateCall() const override {
|
||||
// Create a new `ServerCall`. This object will eventually be deleted by
|
||||
// `GrpcServer::PollEventsFromCompletionQueue`.
|
||||
auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
|
||||
*this, service_handler_, handle_request_function_);
|
||||
/// Request gRPC runtime to starting accepting this kind of request, using the call as
|
||||
/// the tag.
|
||||
(service_.*request_call_function_)(&call->context_, &call->request_,
|
||||
&call->response_writer_, cq_.get(), cq_.get(),
|
||||
call);
|
||||
return call;
|
||||
}
|
||||
|
||||
private:
|
||||
/// The gRPC-generated `AsyncService`.
|
||||
AsyncService &service_;
|
||||
|
||||
/// Pointer to the `AsyncService::RequestMethod` function.
|
||||
RequestCallFunction<GrpcService, Request, Reply> request_call_function_;
|
||||
|
||||
/// The service handler that handles the request.
|
||||
ServiceHandler &service_handler_;
|
||||
|
||||
/// Pointer to the service handler function.
|
||||
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function_;
|
||||
|
||||
/// The `CompletionQueue`.
|
||||
const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;
|
||||
};
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,33 @@
|
||||
#ifndef RAY_RPC_UTIL_H
|
||||
#define RAY_RPC_UTIL_H
|
||||
|
||||
#include <grpcpp/grpcpp.h>
|
||||
|
||||
#include "ray/common/status.h"
|
||||
|
||||
namespace ray {
|
||||
namespace rpc {
|
||||
|
||||
/// Helper function that converts a ray status to gRPC status.
|
||||
inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) {
|
||||
if (ray_status.ok()) {
|
||||
return grpc::Status::OK;
|
||||
} else {
|
||||
// TODO(hchen): Use more specific error code.
|
||||
return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message());
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper function that converts a gRPC status to ray status.
|
||||
inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
|
||||
if (grpc_status.ok()) {
|
||||
return Status::OK();
|
||||
} else {
|
||||
return Status::IOError(grpc_status.error_message());
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace rpc
|
||||
} // namespace ray
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user