[gRPC] Use gRPC for inter-node-manager communication (#4968)

This commit is contained in:
Hao Chen
2019-06-17 19:00:50 +08:00
committed by GitHub
parent b08765a08b
commit 2bf92e02e2
24 changed files with 954 additions and 214 deletions
+2
View File
@@ -2,3 +2,5 @@
build --compilation_mode=opt
build --action_env=PATH
build --action_env=PYTHON_BIN_PATH
# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341
build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD"
+1 -1
View File
@@ -126,7 +126,7 @@ matrix:
- ./ci/suppress_output ./ci/travis/install-dependencies.sh
# This command should be kept in sync with ray/python/README-building-wheels.md.
- ./python/build-wheel-macos.sh
- ./ci/suppress_output ./python/build-wheel-macos.sh
script:
- if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi
+26
View File
@@ -1,12 +1,37 @@
# Bazel build
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("@//bazel:ray.bzl", "flatbuffer_py_library")
load("@//bazel:cython_library.bzl", "pyx_library")
COPTS = ["-DRAY_USE_GLOG"]
# Node manager gRPC lib.
grpc_proto_library(
name = "node_manager_grpc_lib",
srcs = ["src/ray/protobuf/node_manager.proto"],
)
# Node manager server and client.
cc_library(
name = "node_manager_rpc_lib",
srcs = glob([
"src/ray/rpc/*.cc",
]),
hdrs = glob([
"src/ray/rpc/*.h",
]),
copts = COPTS,
deps = [
":node_manager_grpc_lib",
":ray_common",
"@boost//:asio",
"@com_github_grpc_grpc//:grpc++",
],
)
cc_binary(
name = "raylet",
srcs = ["src/ray/raylet/main.cc"],
@@ -89,6 +114,7 @@ cc_library(
":gcs",
":gcs_fbs",
":node_manager_fbs",
":node_manager_rpc_lib",
":object_manager",
":ray_common",
":ray_util",
+4
View File
@@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories")
load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure")
load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps")
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
def ray_deps_build_all():
gen_java_deps()
@@ -10,3 +12,5 @@ def ray_deps_build_all():
boost_deps()
prometheus_cpp_repositories()
python_configure(name = "local_config_python")
grpc_deps()
+8
View File
@@ -101,3 +101,11 @@ def ray_deps_setup():
# `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged.
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"],
)
http_archive(
name = "com_github_grpc_grpc",
urls = [
"https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz",
],
strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49",
)
+1 -1
View File
@@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1
CODE=$?
if [ $CODE != 0 ]; then
cat $TMPFILE
tail -n 2000 $TMPFILE
echo "FAILED $CODE"
kill $WATCHDOG_PID
exit $CODE
+1 -1
View File
@@ -16,7 +16,7 @@ else
exit 1
fi
URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh"
URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh"
wget -O install.sh $URL
chmod +x install.sh
./install.sh --user
+5 -2
View File
@@ -94,11 +94,14 @@ define_java_module(
":org_ray_ray_api",
":org_ray_ray_runtime",
"@plasma//:org_apache_arrow_arrow_plasma",
"@maven//:com_google_guava_guava",
"@maven//:com_sun_xml_bind_jaxb_core",
"@maven//:com_sun_xml_bind_jaxb_impl",
"@maven//:commons_io_commons_io",
"@maven//:javax_xml_bind_jaxb_api",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_slf4j_slf4j_api",
"@maven//:org_testng_testng",
"@maven//:com_google_guava_guava",
"@maven//:commons_io_commons_io",
],
)
+15
View File
@@ -32,11 +32,26 @@
<artifactId>guava</artifactId>
<version>27.0.1-jre</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-core</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>com.sun.xml.bind</groupId>
<artifactId>jaxb-impl</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.5</version>
</dependency>
<dependency>
<groupId>javax.xml.bind</groupId>
<artifactId>jaxb-api</artifactId>
<version>2.3.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
@@ -6,6 +6,7 @@ import numpy as np
import pytest
import ray
from ray.tests.conftest import _ray_start_cluster
num_tasks_submitted = [10**n for n in range(0, 6)]
num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted]
@@ -41,3 +42,25 @@ def test_task_submission(benchmark, num_tasks):
warmup()
benchmark(benchmark_task_submission, num_tasks)
ray.shutdown()
def benchmark_task_forward(f, num_tasks):
ray.get([f.remote() for _ in range(num_tasks)])
@pytest.mark.benchmark
@pytest.mark.parametrize(
"num_tasks", [10**3, 10**4],
ids=[str(num) + "_tasks" for num in [10**3, 10**4]])
def test_task_forward(benchmark, num_tasks):
with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster:
cluster.add_node(resources={"my_resource": 100})
ray.init(redis_address=cluster.redis_address)
@ray.remote(resources={"my_resource": 0.001})
def f():
return 1
# Warm up
ray.get([f.remote() for _ in range(100)])
benchmark(benchmark_task_forward, f, num_tasks)
+24
View File
@@ -0,0 +1,24 @@
syntax = "proto3";
package ray.rpc;
message ForwardTaskRequest {
// The ID of the task to be forwarded.
bytes task_id = 1;
// The tasks in the uncommitted lineage of the forwarded task. This
// should include task_id.
// TODO(hchen): Currently, `uncommitted_tasks` are represented as
// flatbutters-serialized bytes. This is because the flatbuffers-defined Task data
// structure is being used in many places. We should move Task and all related data
// strucutres to protobuf.
repeated bytes uncommitted_tasks = 2;
}
message ForwardTaskReply {
}
// Service for inter-node-manager communication.
service NodeManagerService {
// Forward a task and its uncommitted lineage to the remote node manager.
rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply);
}
+87 -144
View File
@@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
lineage_cache_(gcs_client_->client_table().GetLocalClientId(),
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
config.max_lineage_size),
remote_clients_(),
remote_server_connections_(),
actor_registry_() {
actor_registry_(),
node_manager_server_(config.node_manager_port, io_service, *this),
client_call_manager_(io_service) {
RAY_CHECK(heartbeat_period_.count() > 0);
// Initialize the resource map with own cluster resource configuration.
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
@@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
[this](const ObjectID &object_id) { HandleObjectMissing(object_id); }));
RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
// Run the node manger rpc server.
node_manager_server_.Run();
}
ray::Status NodeManager::RegisterGcs() {
@@ -366,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
return;
}
// TODO(atumanov): make remote client lookup O(1)
if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) ==
remote_clients_.end()) {
remote_clients_.push_back(client_id);
} else {
// NodeManager connection to this client was already established.
RAY_LOG(DEBUG) << "Received a new client connection that already exists: "
auto entry = remote_node_manager_clients_.find(client_id);
if (entry != remote_node_manager_clients_.end()) {
RAY_LOG(DEBUG) << "Received notification of a new client that already exists: "
<< client_id;
return;
}
// Establish a new NodeManager connection to this GCS client.
auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address,
client_data.node_manager_port);
if (!status.ok()) {
// This is not a fatal error for raylet, but it should not happen.
// We need to broadcase this message.
std::string type = "raylet_connection_error";
std::ostringstream error_message;
error_message << "Failed to connect to ray node " << client_id
<< " with status: " << status.ToString()
<< ". This may be since the node was recently removed.";
// We use the nil DriverID to broadcast the message to all drivers.
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
DriverID::Nil(), type, error_message.str(), current_time_ms()));
return;
}
// Initialize a rpc client to the new node manager.
std::unique_ptr<rpc::NodeManagerClient> client(
new rpc::NodeManagerClient(client_data.node_manager_address,
client_data.node_manager_port, client_call_manager_));
remote_node_manager_clients_.emplace(client_id, std::move(client));
ResourceSet resources_total(client_data.resources_total_label,
client_data.resources_total_capacity);
cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
}
ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id,
const std::string &client_address,
int32_t client_port) {
// Establish a new NodeManager connection to this GCS client.
RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at "
<< client_address << ":" << client_port;
boost::asio::ip::tcp::socket socket(io_service_);
RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port));
// The client is connected, now send a connect message to remote node manager.
auto server_conn = TcpServerConnection::Create(std::move(socket));
// Prepare client connection info buffer
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_));
fbb.Finish(message);
// Send synchronously.
// TODO(swang): Make this a WriteMessageAsync.
RAY_RETURN_NOT_OK(server_conn->WriteMessage(
static_cast<int64_t>(protocol::MessageType::ConnectClient), fbb.GetSize(),
fbb.GetBufferPointer()));
remote_server_connections_.emplace(client_id, std::move(server_conn));
return ray::Status::OK();
}
void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// TODO(swang): If we receive a notification for our own death, clean up and
// exit immediately.
@@ -440,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// check that it is actually removed, or log a warning otherwise, but that may
// not be necessary.
// Remove the client from the list of remote clients.
std::remove(remote_clients_.begin(), remote_clients_.end(), client_id);
// Remove the client from the resource map.
cluster_resource_map_.erase(client_id);
// Remove the remote server connection.
const auto connection_entry = remote_server_connections_.find(client_id);
if (connection_entry != remote_server_connections_.end()) {
connection_entry->second->Close();
remote_server_connections_.erase(connection_entry);
// Remove the node manager client.
const auto client_entry = remote_node_manager_clients_.find(client_id);
if (client_entry != remote_node_manager_clients_.end()) {
remote_node_manager_clients_.erase(client_entry);
} else {
RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client "
<< client_id << ".";
@@ -1241,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client
node_manager_client.ProcessMessages();
}
void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
int64_t message_type,
const uint8_t *message_data) {
const auto message_type_value = static_cast<protocol::MessageType>(message_type);
RAY_LOG(DEBUG) << "[NodeManager] Message "
<< protocol::EnumNameMessageType(message_type_value) << "("
<< message_type << ") from node manager";
switch (message_type_value) {
case protocol::MessageType::ConnectClient: {
auto message = flatbuffers::GetRoot<protocol::ConnectClient>(message_data);
auto client_id = from_flatbuf<ClientID>(*message->client_id());
node_manager_client.SetClientID(client_id);
} break;
case protocol::MessageType::ForwardTaskRequest: {
auto message = flatbuffers::GetRoot<protocol::ForwardTaskRequest>(message_data);
TaskID task_id = from_flatbuf<TaskID>(*message->task_id());
Lineage uncommitted_lineage(*message);
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
<< " on node " << gcs_client_->client_table().GetLocalClientId()
<< " spillback=" << task.GetTaskExecutionSpec().NumForwards();
SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
} break;
case protocol::MessageType::DisconnectClient: {
// TODO(rkn): We need to do some cleanup here.
RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. "
<< "We need to do some cleanup here.";
// Do not process any more messages from this node manager.
return;
} break;
default:
RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::RequestDoneCallback done_callback) {
// Get the forwarded task and its uncommitted lineage from the request.
TaskID task_id = TaskID::FromBinary(request.task_id());
Lineage uncommitted_lineage;
for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
const std::string &task_message = request.uncommitted_tasks(i);
const Task task(*flatbuffers::GetRoot<protocol::Task>(
reinterpret_cast<const uint8_t *>(task_message.data())));
RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED));
}
node_manager_client.ProcessMessages();
const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
<< " on node " << gcs_client_->client_table().GetLocalClientId()
<< " spillback=" << task.GetTaskExecutionSpec().NumForwards();
SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
done_callback(Status::OK());
}
void NodeManager::ProcessSetResourceRequest(
@@ -2253,6 +2192,16 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
void NodeManager::ForwardTask(
const Task &task, const ClientID &node_id,
const std::function<void(const ray::Status &, const Task &)> &on_error) {
// Lookup node manager client for this node_id and use it to send the request.
auto client_entry = remote_node_manager_clients_.find(node_id);
if (client_entry == remote_node_manager_clients_.end()) {
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id;
on_error(ray::Status::IOError("Node manager client not found"), task);
return;
}
auto &client = client_entry->second;
const auto &spec = task.GetTaskSpecification();
auto task_id = spec.TaskId();
@@ -2272,68 +2221,61 @@ void NodeManager::ForwardTask(
// Increment forward count for the forwarded task.
lineage_cache_entry_task.IncrementNumForwards();
flatbuffers::FlatBufferBuilder fbb;
auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id);
fbb.Finish(request);
RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from "
<< gcs_client_->client_table().GetLocalClientId() << " to " << node_id
<< " spillback="
<< lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards();
// Lookup remote server connection for this node_id and use it to send the request.
auto it = remote_server_connections_.find(node_id);
if (it == remote_server_connections_.end()) {
// TODO(atumanov): caller must handle failure to ensure tasks are not lost.
RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
on_error(ray::Status::IOError("NodeManager connection not found"), task);
return;
// Prepare the request message.
rpc::ForwardTaskRequest request;
request.set_task_id(task_id.Binary());
for (auto &entry : uncommitted_lineage.GetEntries()) {
request.add_uncommitted_tasks(entry.second.TaskData().Serialize());
}
auto &server_conn = it->second;
// Move the FORWARDING task to the SWAP queue so that we remember that we
// have it queued locally. Once the ForwardTaskRequest has been sent, the
// task will get re-queued, depending on whether the message succeeded or
// not.
local_queues_.QueueTasks({task}, TaskState::SWAP);
server_conn->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(),
fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) {
// Remove the FORWARDING task from the SWAP queue.
TaskState state;
const auto task = local_queues_.RemoveTask(task_id, &state);
RAY_CHECK(state == TaskState::SWAP);
client->ForwardTask(request, [this, on_error, task_id, node_id](
Status status, const rpc::ForwardTaskReply &reply) {
// Remove the FORWARDING task from the SWAP queue.
TaskState state;
const auto task = local_queues_.RemoveTask(task_id, &state);
RAY_CHECK(state == TaskState::SWAP);
if (status.ok()) {
const auto &spec = task.GetTaskSpecification();
// Mark as forwarded so that the task and its lineage are not
// re-forwarded in the future to the receiving node.
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
if (status.ok()) {
const auto &spec = task.GetTaskSpecification();
// Mark as forwarded so that the task and its lineage are not
// re-forwarded in the future to the receiving node.
lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
// Notify the task dependency manager that we are no longer responsible
// for executing this task.
task_dependency_manager_.TaskCanceled(task_id);
// Preemptively push any local arguments to the receiving node. For now, we
// only do this with actor tasks, since actor tasks must be executed by a
// specific process and therefore have affinity to the receiving node.
if (spec.IsActorTask()) {
// Iterate through the object's arguments. NOTE(swang): We do not include
// the execution dependencies here since those cannot be transferred
// between nodes.
for (int i = 0; i < spec.NumArgs(); ++i) {
int count = spec.ArgIdCount(i);
for (int j = 0; j < count; j++) {
ObjectID argument_id = spec.ArgId(i, j);
// If the argument is local, then push it to the receiving node.
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
object_manager_.Push(argument_id, node_id);
}
}
// Notify the task dependency manager that we are no longer responsible
// for executing this task.
task_dependency_manager_.TaskCanceled(task_id);
// Preemptively push any local arguments to the receiving node. For now, we
// only do this with actor tasks, since actor tasks must be executed by a
// specific process and therefore have affinity to the receiving node.
if (spec.IsActorTask()) {
// Iterate through the object's arguments. NOTE(swang): We do not include
// the execution dependencies here since those cannot be transferred
// between nodes.
for (int i = 0; i < spec.NumArgs(); ++i) {
int count = spec.ArgIdCount(i);
for (int j = 0; j < count; j++) {
ObjectID argument_id = spec.ArgId(i, j);
// If the argument is local, then push it to the receiving node.
if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
object_manager_.Push(argument_id, node_id);
}
}
} else {
on_error(status, task);
}
});
}
} else {
on_error(status, task);
}
});
}
void NodeManager::DumpDebugState() const {
@@ -2368,10 +2310,11 @@ std::string NodeManager::DebugString() const {
result << "\n- num dead actors: " << statistical_data.dead_actors;
result << "\n- max num handles: " << statistical_data.max_num_handles;
result << "\nRemoteConnections:";
for (auto &pair : remote_server_connections_) {
result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString();
result << "\nRemote node manager clients: ";
for (const auto &entry : remote_node_manager_clients_) {
result << "\n" << entry.first;
}
result << "\nDebugString() time ms: " << (current_time_ms() - now_ms);
return result.str();
}
+21 -22
View File
@@ -4,6 +4,9 @@
#include <boost/asio/steady_timer.hpp>
// clang-format off
#include "ray/rpc/client_call.h"
#include "ray/rpc/node_manager_server.h"
#include "ray/rpc/node_manager_client.h"
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
@@ -52,7 +55,7 @@ struct NodeManagerConfig {
std::string session_dir;
};
class NodeManager {
class NodeManager : public rpc::NodeManagerServiceHandler {
public:
/// Create a node manager.
///
@@ -86,15 +89,6 @@ class NodeManager {
/// \return Void.
void ProcessNewNodeManager(TcpClientConnection &node_manager_client);
/// Handle a message from a remote node manager.
///
/// \param node_manager_client The connection to the remote node manager.
/// \param message_type The type of the message.
/// \param message The message contents.
/// \return Void.
void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
int64_t message_type, const uint8_t *message);
/// Subscribe to the relevant GCS tables and set up handlers.
///
/// \return Status indicating whether this was done successfully or not.
@@ -108,6 +102,9 @@ class NodeManager {
/// Record metrics.
void RecordMetrics() const;
/// Get the port of the node manager rpc server.
int GetServerPort() const { return node_manager_server_.GetPort(); }
private:
/// Methods for handling clients.
@@ -450,15 +447,10 @@ class NodeManager {
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
bool intentional_disconnect);
/// connect to a remote node manager.
///
/// \param client_id The client ID for the remote node manager.
/// \param client_address The IP address for the remote node manager.
/// \param client_port The listening port for the remote node manager.
/// \return True if the connect succeeds.
ray::Status ConnectRemoteNodeManager(const ClientID &client_id,
const std::string &client_address,
int32_t client_port);
/// Handle a `ForwardTask` request.
void HandleForwardTask(const rpc::ForwardTaskRequest &request,
rpc::ForwardTaskReply *reply,
rpc::RequestDoneCallback done_callback) override;
// GCS client ID for this node.
ClientID client_id_;
@@ -505,9 +497,6 @@ class NodeManager {
TaskDependencyManager task_dependency_manager_;
/// The lineage cache for the GCS object and task tables.
LineageCache lineage_cache_;
std::vector<ClientID> remote_clients_;
std::unordered_map<ClientID, std::shared_ptr<TcpServerConnection>>
remote_server_connections_;
/// A mapping from actor ID to registration information about that actor
/// (including which node manager owns it).
std::unordered_map<ActorID, ActorRegistration> actor_registry_;
@@ -515,6 +504,16 @@ class NodeManager {
/// This map stores actor ID to the ID of the checkpoint that will be used to
/// restore the actor.
std::unordered_map<ActorID, ActorCheckpointID> checkpoint_id_to_restore_;
/// The RPC server.
rpc::NodeManagerServer node_manager_server_;
/// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
rpc::ClientCallManager client_call_manager_;
/// Map from node ids to clients of the remote node managers.
std::unordered_map<ClientID, std::unique_ptr<rpc::NodeManagerClient>>
remote_node_manager_clients_;
};
} // namespace raylet
+2 -34
View File
@@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_
main_service,
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(),
object_manager_config.object_manager_port)),
object_manager_socket_(main_service),
node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint(
boost::asio::ip::tcp::v4(),
node_manager_config.node_manager_port)),
node_manager_socket_(main_service) {
object_manager_socket_(main_service) {
// Start listening for clients.
DoAccept();
DoAcceptObjectManager();
DoAcceptNodeManager();
RAY_CHECK_OK(RegisterGcs(
node_ip_address, socket_name_, object_manager_config.store_socket_name,
@@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
client_info.raylet_socket_name = raylet_socket_name;
client_info.object_store_socket_name = object_store_socket_name;
client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port();
client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port();
client_info.node_manager_port = node_manager_.GetServerPort();
// Add resource information.
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) {
client_info.resources_total_label.push_back(resource_pair.first);
@@ -120,33 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
return Status::OK();
}
void Raylet::DoAcceptNodeManager() {
node_manager_acceptor_.async_accept(node_manager_socket_,
boost::bind(&Raylet::HandleAcceptNodeManager, this,
boost::asio::placeholders::error));
}
void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
if (!error) {
ClientHandler<boost::asio::ip::tcp> client_handler =
[this](TcpClientConnection &client) {
node_manager_.ProcessNewNodeManager(client);
};
MessageHandler<boost::asio::ip::tcp> message_handler =
[this](std::shared_ptr<TcpClientConnection> client, int64_t message_type,
const uint8_t *message) {
node_manager_.ProcessNodeManagerMessage(*client, message_type, message);
};
// Accept a new TCP client and dispatch it to the node manager.
auto new_connection = TcpClientConnection::Create(
client_handler, message_handler, std::move(node_manager_socket_), "node manager",
node_manager_message_enum,
static_cast<int64_t>(protocol::MessageType::DisconnectClient));
}
// We're ready to accept another client.
DoAcceptNodeManager();
}
void Raylet::DoAcceptObjectManager() {
object_manager_acceptor_.async_accept(
object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this,
-6
View File
@@ -63,8 +63,6 @@ class Raylet {
void DoAcceptObjectManager();
/// Handle an accepted tcp client connection.
void HandleAcceptObjectManager(const boost::system::error_code &error);
void DoAcceptNodeManager();
void HandleAcceptNodeManager(const boost::system::error_code &error);
friend class TestObjectManagerIntegration;
@@ -88,10 +86,6 @@ class Raylet {
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
/// The socket to listen on for new object manager tcp clients.
boost::asio::ip::tcp::socket object_manager_socket_;
/// An acceptor for new tcp clients.
boost::asio::ip::tcp::acceptor node_manager_acceptor_;
/// The socket to listen on for new tcp clients.
boost::asio::ip::tcp::socket node_manager_socket_;
};
} // namespace raylet
+7 -3
View File
@@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) {
ComputeDependencies();
}
const std::string Task::Serialize() const {
flatbuffers::FlatBufferBuilder fbb;
fbb.Finish(ToFlatbuffer(fbb));
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
}
std::string SerializeTaskAsString(const std::vector<ObjectID> *dependencies,
const TaskSpecification *task_spec) {
flatbuffers::FlatBufferBuilder fbb;
std::vector<ObjectID> execution_dependencies(*dependencies);
TaskExecutionSpecification execution_spec(std::move(execution_dependencies));
Task task(execution_spec, *task_spec);
fbb.Finish(task.ToFlatbuffer(fbb));
return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
return task.Serialize();
}
} // namespace raylet
+3
View File
@@ -84,6 +84,9 @@ class Task {
/// \param task Task structure with updated dynamic information.
void CopyTaskExecutionSpec(const Task &task);
/// Serialize this task as a string.
const std::string Serialize() const;
private:
void ComputeDependencies();
+169
View File
@@ -0,0 +1,169 @@
#ifndef RAY_RPC_CLIENT_CALL_H
#define RAY_RPC_CLIENT_CALL_H
#include <grpcpp/grpcpp.h>
#include <boost/asio.hpp>
#include "ray/common/status.h"
#include "ray/rpc/util.h"
namespace ray {
namespace rpc {
/// Represents an outgoing gRPC request.
///
/// The lifecycle of a `ClientCall` is as follows.
///
/// When a client submits a new gRPC request, a new `ClientCall` object will be created
/// by `ClientCallMangager::CreateCall`. Then the object will be used as the tag of
/// `CompletionQueue`.
///
/// When the reply is received, `ClientCallMangager` will get the address of this object
/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then
/// delete this object.
///
/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use
/// template. This allows the users (e.g., `ClientCallMangager`) not having to use
/// template as well.
class ClientCall {
public:
/// The callback to be called by `ClientCallManager` when the reply of this request is
/// received.
virtual void OnReplyReceived() = 0;
};
class ClientCallManager;
/// Reprents the client callback function of a particular rpc method.
///
/// \tparam Reply Type of the reply message.
template <class Reply>
using ClientCallback = std::function<void(const Status &status, const Reply &reply)>;
/// Implementaion of the `ClientCall`. It represents a `ClientCall` for a particular
/// RPC method.
///
/// \tparam Reply Type of the Reply message.
template <class Reply>
class ClientCallImpl : public ClientCall {
public:
void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); }
private:
/// Constructor.
///
/// \param[in] callback The callback function to handle the reply.
ClientCallImpl(const ClientCallback<Reply> &callback) : callback_(callback) {}
/// The reply message.
Reply reply_;
/// The callback function to handle the reply.
ClientCallback<Reply> callback_;
/// The response reader.
std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> response_reader_;
/// gRPC status of this request.
grpc::Status status_;
/// Context for the client. It could be used to convey extra information to
/// the server and/or tweak certain RPC behaviors.
grpc::ClientContext context_;
friend class ClientCallManager;
};
/// Peprents the generic signature of a `FooService::Stub::PrepareAsyncBar`
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class Request, class Reply>
using PrepareAsyncFunction = std::unique_ptr<grpc::ClientAsyncResponseReader<Reply>> (
GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request,
grpc::CompletionQueue *cq);
/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of
/// `ClientCall` objects.
///
/// It maintains a thread that keeps polling events from `CompletionQueue`, and post
/// the callback function to the main event loop when a reply is received.
///
/// Mutiple clients can share one `ClientCallManager`.
class ClientCallManager {
public:
/// Constructor.
///
/// \param[in] main_service The main event loop, to which the callback functions will be
/// posted.
ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) {
// Start the polling thread.
std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
polling_thread.detach();
}
~ClientCallManager() { cq_.Shutdown(); }
/// Create a new `ClientCall` and send request.
///
/// \param[in] stub The gRPC-generated stub.
/// \param[in] prepare_async_function Pointer to the gRPC-generated
/// `FooService::Stub::PrepareAsyncBar` function.
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class Request, class Reply>
ClientCall *CreateCall(
typename GrpcService::Stub &stub,
const PrepareAsyncFunction<GrpcService, Request, Reply> prepare_async_function,
const Request &request, const ClientCallback<Reply> &callback) {
// Create a new `ClientCall` object. This object will eventuall be deleted in the
// `ClientCallManager::PollEventsFromCompletionQueue` when reply is received.
auto call = new ClientCallImpl<Reply>(callback);
// Send request.
call->response_reader_ =
(stub.*prepare_async_function)(&call->context_, request, &cq_);
call->response_reader_->StartCall();
call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call);
return call;
}
private:
/// This function runs in a background thread. It keeps polling events from the
/// `CompletionQueue`, and dispaches the event to the callbacks via the `ClientCall`
/// objects.
void PollEventsFromCompletionQueue() {
void *got_tag;
bool ok = false;
// Keep reading events from the `CompletionQueue` until it's shutdown.
while (cq_.Next(&got_tag, &ok)) {
ClientCall *call = reinterpret_cast<ClientCall *>(got_tag);
if (ok) {
// Post the callback to the main event loop.
main_service_.post([call]() {
call->OnReplyReceived();
// The call is finished, we can delete the `ClientCall` object now.
delete call;
});
} else {
delete call;
}
}
}
/// The main event loop, to which the callback functions will be posted.
boost::asio::io_service &main_service_;
/// The gRPC `CompletionQueue` object used to poll events.
grpc::CompletionQueue cq_;
};
} // namespace rpc
} // namespace ray
#endif
+70
View File
@@ -0,0 +1,70 @@
#include "ray/rpc/grpc_server.h"
namespace ray {
namespace rpc {
void GrpcServer::Run() {
std::string server_address("0.0.0.0:" + std::to_string(port_));
grpc::ServerBuilder builder;
// TODO(hchen): Add options for authentication.
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
// Allow subclasses to register concrete services.
RegisterServices(builder);
// Get hold of the completion queue used for the asynchronous communication
// with the gRPC runtime.
cq_ = builder.AddCompletionQueue();
// Build and start server.
server_ = builder.BuildAndStart();
RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << ".";
// Allow subclasses to initialize the server call factories.
InitServerCallFactories(&server_call_factories_and_concurrencies_);
for (auto &entry : server_call_factories_and_concurrencies_) {
for (int i = 0; i < entry.second; i++) {
// Create and request calls from the factory.
entry.first->CreateCall();
}
}
// Start a thread that polls incoming requests.
std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this);
polling_thread.detach();
}
void GrpcServer::PollEventsFromCompletionQueue() {
void *tag;
bool ok;
// Keep reading events from the `CompletionQueue` until it's shutdown.
while (cq_->Next(&tag, &ok)) {
ServerCall *server_call = static_cast<ServerCall *>(tag);
// `ok == false` indicates that the server has been shut down.
// We should delete the call object in this case.
bool delete_call = !ok;
if (ok) {
switch (server_call->GetState()) {
case ServerCallState::PENDING:
// We've received a new incoming request. Now this call object is used to
// track this request. So we need to create another call to handle next
// incoming request.
server_call->GetFactory().CreateCall();
server_call->SetState(ServerCallState::PROCESSING);
main_service_.post([server_call] { server_call->HandleRequest(); });
break;
case ServerCallState::SENDING_REPLY:
// The reply has been sent, this call can be deleted now.
// This event is triggered by `ServerCallImpl::SendReply`.
delete_call = true;
break;
default:
RAY_LOG(FATAL) << "Shouldn't reach here.";
break;
}
}
if (delete_call) {
delete server_call;
}
}
}
} // namespace rpc
} // namespace ray
+92
View File
@@ -0,0 +1,92 @@
#ifndef RAY_RPC_GRPC_SERVER_H
#define RAY_RPC_GRPC_SERVER_H
#include <thread>
#include <grpcpp/grpcpp.h>
#include <boost/asio.hpp>
#include "ray/common/status.h"
#include "ray/rpc/server_call.h"
namespace ray {
namespace rpc {
/// Base class that represents an abstract gRPC server.
///
/// A `GrpcServer` listens on a specific port. It owns
/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
/// 2) and a thread that polls events from the `ServerCompletionQueue`.
///
/// Subclasses can register one or multiple services to a `GrpcServer`, see
/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide
/// which kinds of requests this server should accept.
class GrpcServer {
public:
/// Constructor.
///
/// \param[in] name Name of this server, used for logging and debugging purpose.
/// \param[in] port The port to bind this server to. If it's 0, a random available port
/// will be chosen.
/// \param[in] main_service The main event loop, to which service handler functions
/// will be posted.
GrpcServer(const std::string &name, const uint32_t port,
boost::asio::io_service &main_service)
: name_(name), port_(port), main_service_(main_service) {}
/// Destruct this gRPC server.
~GrpcServer() {
server_->Shutdown();
cq_->Shutdown();
}
/// Initialize and run this server.
void Run();
/// Get the port of this gRPC server.
int GetPort() const { return port_; }
protected:
/// Subclasses should implement this method and register one or multiple gRPC services
/// to the given `ServerBuilder`.
///
/// \param[in] builder The `ServerBuilder` instance to register services to.
virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;
/// Subclasses should implement this method to initialize the `ServerCallFactory`
/// instances, as well as specify maximum number of concurrent requests that gRPC
/// server can "accept" (not "handle"). Each factory will be used to create
/// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
/// handle an incoming request.
///
/// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
/// and the maximum number of concurrent requests that gRPC server can accept.
virtual void InitServerCallFactories(
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) = 0;
/// This function runs in a background thread. It keeps polling events from the
/// `ServerCompletionQueue`, and dispaches the event to the `ServiceHandler` instances
/// via the `ServerCall` objects.
void PollEventsFromCompletionQueue();
/// The main event loop, to which the service handler functions will be posted.
boost::asio::io_service &main_service_;
/// Name of this server, used for logging and debugging purpose.
const std::string name_;
/// Port of this server.
int port_;
/// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
/// gRPC server can accept.
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
server_call_factories_and_concurrencies_;
/// The `ServerCompletionQueue` object used for polling events.
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
/// The `Server` object.
std::unique_ptr<grpc::Server> server_;
};
} // namespace rpc
} // namespace ray
#endif
+56
View File
@@ -0,0 +1,56 @@
#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H
#define RAY_RPC_NODE_MANAGER_CLIENT_H
#include <thread>
#include <grpcpp/grpcpp.h>
#include "ray/common/status.h"
#include "ray/rpc/client_call.h"
#include "ray/util/logging.h"
#include "src/ray/protobuf/node_manager.grpc.pb.h"
#include "src/ray/protobuf/node_manager.pb.h"
namespace ray {
namespace rpc {
/// Client used for communicating with a remote node manager server.
class NodeManagerClient {
public:
/// Constructor.
///
/// \param[in] address Address of the node manager server.
/// \param[in] port Port of the node manager server.
/// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
NodeManagerClient(const std::string &address, const int port,
ClientCallManager &client_call_manager)
: client_call_manager_(client_call_manager) {
std::shared_ptr<grpc::Channel> channel = grpc::CreateChannel(
address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
stub_ = NodeManagerService::NewStub(channel);
};
/// Forward a task and its uncommitted lineage.
///
/// \param[in] request The request message.
/// \param[in] callback The callback function that handles reply.
void ForwardTask(const ForwardTaskRequest &request,
const ClientCallback<ForwardTaskReply> &callback) {
client_call_manager_
.CreateCall<NodeManagerService, ForwardTaskRequest, ForwardTaskReply>(
*stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request,
callback);
}
private:
/// The gRPC-generated stub.
std::unique_ptr<NodeManagerService::Stub> stub_;
/// The `ClientCallManager` used for managing requests.
ClientCallManager &client_call_manager_;
};
} // namespace rpc
} // namespace ray
#endif // RAY_RPC_NODE_MANAGER_CLIENT_H
+71
View File
@@ -0,0 +1,71 @@
#ifndef RAY_RPC_NODE_MANAGER_SERVER_H
#define RAY_RPC_NODE_MANAGER_SERVER_H
#include "ray/rpc/grpc_server.h"
#include "ray/rpc/server_call.h"
#include "src/ray/protobuf/node_manager.grpc.pb.h"
#include "src/ray/protobuf/node_manager.pb.h"
namespace ray {
namespace rpc {
/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
class NodeManagerServiceHandler {
public:
/// Handle a `ForwardTask` request.
/// The implementation can handle this request asynchronously. When hanling is done, the
/// `done_callback` should be called.
///
/// \param[in] request The request message.
/// \param[out] reply The reply message.
/// \param[in] done_callback The callback to be called when the request is done.
virtual void HandleForwardTask(const ForwardTaskRequest &request,
ForwardTaskReply *reply,
RequestDoneCallback done_callback) = 0;
};
/// The `GrpcServer` for `NodeManagerService`.
class NodeManagerServer : public GrpcServer {
public:
/// Constructor.
///
/// \param[in] port See super class.
/// \param[in] main_service See super class.
/// \param[in] handler The service handler that actually handle the requests.
NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
NodeManagerServiceHandler &service_handler)
: GrpcServer("NodeManager", port, main_service),
service_handler_(service_handler){};
void RegisterServices(grpc::ServerBuilder &builder) override {
/// Register `NodeManagerService`.
builder.RegisterService(&service_);
}
void InitServerCallFactories(
std::vector<std::pair<std::unique_ptr<ServerCallFactory>, int>>
*server_call_factories_and_concurrencies) override {
// Initialize the factory for `ForwardTask` requests.
std::unique_ptr<ServerCallFactory> forward_task_call_factory(
new ServerCallFactoryImpl<NodeManagerService, NodeManagerServiceHandler,
ForwardTaskRequest, ForwardTaskReply>(
service_, &NodeManagerService::AsyncService::RequestForwardTask,
service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
// Set `ForwardTask`'s accept concurrency to 100.
server_call_factories_and_concurrencies->emplace_back(
std::move(forward_task_call_factory), 100);
}
private:
/// The grpc async service object.
NodeManagerService::AsyncService service_;
/// The service handler that actually handle the requests.
NodeManagerServiceHandler &service_handler_;
};
} // namespace rpc
} // namespace ray
#endif
+233
View File
@@ -0,0 +1,233 @@
#ifndef RAY_RPC_SERVER_CALL_H
#define RAY_RPC_SERVER_CALL_H
#include <grpcpp/grpcpp.h>
#include "ray/common/status.h"
#include "ray/rpc/util.h"
namespace ray {
namespace rpc {
/// Represents the callback function to be called when a `ServiceHandler` finishes
/// handling a request.
using RequestDoneCallback = std::function<void(Status)>;
/// Represents state of a `ServerCall`.
enum class ServerCallState {
/// The call is created and waiting for an incoming request.
PENDING,
/// Request is received and being processed.
PROCESSING,
/// Request processing is done, and reply is being sent to client.
SENDING_REPLY
};
class ServerCallFactory;
/// Reprensents an incoming request of a gRPC server.
///
/// The lifecycle and state transition of a `ServerCall` is as follows:
///
/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED]
///
/// (1) The `GrpcServer` creates a `ServerCall` and use it as the tag to accept requests
/// gRPC `CompletionQueue`. Now the state is `PENDING`.
/// (2) When a request is received, an event will be gotten from the `CompletionQueue`.
/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call
/// `ServerCall::HandleRequest`.
/// (3) When the `ServiceHandler` finishes handling the request, `ServerCallImpl::Finish`
/// will be called, and the state becomes `SENDING_REPLY`.
/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`.
/// `GrpcServer` will then delete this call.
///
/// NOTE(hchen): Compared to `ServerCallImpl`, this abstract interface doesn't use
/// template. This allows the users (e.g., `GrpcServer`) not having to use
/// template as well.
class ServerCall {
public:
/// Get the state of this `ServerCall`.
virtual ServerCallState GetState() const = 0;
/// Set state of this `ServerCall`.
virtual void SetState(const ServerCallState &new_state) = 0;
/// Handle the requst. This is the callback function to be called by
/// `GrpcServer` when the request is received.
virtual void HandleRequest() = 0;
/// Get the factory that created this `ServerCall`.
virtual const ServerCallFactory &GetFactory() const = 0;
};
/// The factory that creates a particular kind of `ServerCall` objects.
class ServerCallFactory {
public:
/// Create a new `ServerCall` and request gRPC runtime to start accepting the
/// corresonding type of requests.
///
/// \return Pointer to the `ServerCall` object.
virtual ServerCall *CreateCall() const = 0;
};
/// Represents the generic signature of a `FooServiceHandler::HandleBar()`
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
///
/// \tparam ServiceHandler Type of the handler that handles the request.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class ServiceHandler, class Request, class Reply>
using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *,
RequestDoneCallback);
/// Implementation of `ServerCall`. It represents `ServerCall` for a particular
/// RPC method.
///
/// \tparam ServiceHandler Type of the handler that handles the request.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class ServiceHandler, class Request, class Reply>
class ServerCallImpl : public ServerCall {
public:
/// Constructor.
///
/// \param[in] factory The factory which created this call.
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
ServerCallImpl(
const ServerCallFactory &factory, ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function)
: state_(ServerCallState::PENDING),
factory_(factory),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
response_writer_(&context_) {}
ServerCallState GetState() const override { return state_; }
void SetState(const ServerCallState &new_state) override { state_ = new_state; }
void HandleRequest() override {
state_ = ServerCallState::PROCESSING;
(service_handler_.*handle_request_function_)(request_, &reply_,
[this](Status status) {
// When the handler is done with the
// request, tell gRPC to finish this
// request.
SendReply(status);
});
}
const ServerCallFactory &GetFactory() const override { return factory_; }
private:
/// Tell gRPC to finish this request.
void SendReply(Status status) {
state_ = ServerCallState::SENDING_REPLY;
response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
}
/// State of this call.
ServerCallState state_;
/// The factory which created this call.
const ServerCallFactory &factory_;
/// The service handler that handles the request.
ServiceHandler &service_handler_;
/// Pointer to the service handler function.
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function_;
/// Context for the request, allowing to tweak aspects of it such as the use
/// of compression, authentication, as well as to send metadata back to the client.
grpc::ServerContext context_;
/// The reponse writer.
grpc::ServerAsyncResponseWriter<Reply> response_writer_;
/// The request message.
Request request_;
/// The reply message.
Reply reply_;
template <class T1, class T2, class T3, class T4>
friend class ServerCallFactoryImpl;
};
/// Represents the generic signature of a `FooService::AsyncService::RequestBar()`
/// function, where `Foo` is the service name and `Bar` is the rpc method name.
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class Request, class Reply>
using RequestCallFunction = void (GrpcService::AsyncService::*)(
grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter<Reply> *,
grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
/// Implementation of `ServerCallFactory`
///
/// \tparam GrpcService Type of the gRPC-generated service class.
/// \tparam ServiceHandler Type of the handler that handles the request.
/// \tparam Request Type of the request message.
/// \tparam Reply Type of the reply message.
template <class GrpcService, class ServiceHandler, class Request, class Reply>
class ServerCallFactoryImpl : public ServerCallFactory {
using AsyncService = typename GrpcService::AsyncService;
public:
/// Constructor.
///
/// \param[in] service The gRPC-generated `AsyncService`.
/// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod`
// function.
/// \param[in] service_handler The service handler that handles the request.
/// \param[in] handle_request_function Pointer to the service handler function.
/// \param[in] cq The `CompletionQueue`.
ServerCallFactoryImpl(
AsyncService &service,
RequestCallFunction<GrpcService, Request, Reply> request_call_function,
ServiceHandler &service_handler,
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function,
const std::unique_ptr<grpc::ServerCompletionQueue> &cq)
: service_(service),
request_call_function_(request_call_function),
service_handler_(service_handler),
handle_request_function_(handle_request_function),
cq_(cq) {}
ServerCall *CreateCall() const override {
// Create a new `ServerCall`. This object will eventually be deleted by
// `GrpcServer::PollEventsFromCompletionQueue`.
auto call = new ServerCallImpl<ServiceHandler, Request, Reply>(
*this, service_handler_, handle_request_function_);
/// Request gRPC runtime to starting accepting this kind of request, using the call as
/// the tag.
(service_.*request_call_function_)(&call->context_, &call->request_,
&call->response_writer_, cq_.get(), cq_.get(),
call);
return call;
}
private:
/// The gRPC-generated `AsyncService`.
AsyncService &service_;
/// Pointer to the `AsyncService::RequestMethod` function.
RequestCallFunction<GrpcService, Request, Reply> request_call_function_;
/// The service handler that handles the request.
ServiceHandler &service_handler_;
/// Pointer to the service handler function.
HandleRequestFunction<ServiceHandler, Request, Reply> handle_request_function_;
/// The `CompletionQueue`.
const std::unique_ptr<grpc::ServerCompletionQueue> &cq_;
};
} // namespace rpc
} // namespace ray
#endif
+33
View File
@@ -0,0 +1,33 @@
#ifndef RAY_RPC_UTIL_H
#define RAY_RPC_UTIL_H
#include <grpcpp/grpcpp.h>
#include "ray/common/status.h"
namespace ray {
namespace rpc {
/// Helper function that converts a ray status to gRPC status.
inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) {
if (ray_status.ok()) {
return grpc::Status::OK;
} else {
// TODO(hchen): Use more specific error code.
return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message());
}
}
/// Helper function that converts a gRPC status to ray status.
inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
if (grpc_status.ok()) {
return Status::OK();
} else {
return Status::IOError(grpc_status.error_message());
}
}
} // namespace rpc
} // namespace ray
#endif