diff --git a/.bazelrc b/.bazelrc
index 488b33101..3e3c3b6c4 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -2,3 +2,5 @@
build --compilation_mode=opt
build --action_env=PATH
build --action_env=PYTHON_BIN_PATH
+# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341
+build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD"
diff --git a/.travis.yml b/.travis.yml
index 1888fa4ce..9a4fb66d8 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -126,7 +126,7 @@ matrix:
- ./ci/suppress_output ./ci/travis/install-dependencies.sh
# This command should be kept in sync with ray/python/README-building-wheels.md.
- - ./python/build-wheel-macos.sh
+ - ./ci/suppress_output ./python/build-wheel-macos.sh
script:
- if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi
diff --git a/BUILD.bazel b/BUILD.bazel
index 47e795011..da36eec0c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -1,12 +1,37 @@
# Bazel build
# C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html
+load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library")
load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
load("@//bazel:ray.bzl", "flatbuffer_py_library")
load("@//bazel:cython_library.bzl", "pyx_library")
COPTS = ["-DRAY_USE_GLOG"]
+# Node manager gRPC lib.
+grpc_proto_library(
+ name = "node_manager_grpc_lib",
+ srcs = ["src/ray/protobuf/node_manager.proto"],
+)
+
+# Node manager server and client.
+cc_library(
+ name = "node_manager_rpc_lib",
+ srcs = glob([
+ "src/ray/rpc/*.cc",
+ ]),
+ hdrs = glob([
+ "src/ray/rpc/*.h",
+ ]),
+ copts = COPTS,
+ deps = [
+ ":node_manager_grpc_lib",
+ ":ray_common",
+ "@boost//:asio",
+ "@com_github_grpc_grpc//:grpc++",
+ ],
+)
+
cc_binary(
name = "raylet",
srcs = ["src/ray/raylet/main.cc"],
@@ -89,6 +114,7 @@ cc_library(
":gcs",
":gcs_fbs",
":node_manager_fbs",
+ ":node_manager_rpc_lib",
":object_manager",
":ray_common",
":ray_util",
diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl
index 5598d5820..3e1e1838a 100644
--- a/bazel/ray_deps_build_all.bzl
+++ b/bazel/ray_deps_build_all.bzl
@@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps")
load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories")
load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure")
load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps")
+load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
+
def ray_deps_build_all():
gen_java_deps()
@@ -10,3 +12,5 @@ def ray_deps_build_all():
boost_deps()
prometheus_cpp_repositories()
python_configure(name = "local_config_python")
+ grpc_deps()
+
diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl
index b3cd21b9b..e6dc21585 100644
--- a/bazel/ray_deps_setup.bzl
+++ b/bazel/ray_deps_setup.bzl
@@ -101,3 +101,11 @@ def ray_deps_setup():
# `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged.
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"],
)
+
+ http_archive(
+ name = "com_github_grpc_grpc",
+ urls = [
+ "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz",
+ ],
+ strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49",
+ )
diff --git a/ci/suppress_output b/ci/suppress_output
index 623559d11..0f32b1a88 100755
--- a/ci/suppress_output
+++ b/ci/suppress_output
@@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1
CODE=$?
if [ $CODE != 0 ]; then
- cat $TMPFILE
+ tail -n 2000 $TMPFILE
echo "FAILED $CODE"
kill $WATCHDOG_PID
exit $CODE
diff --git a/ci/travis/install-bazel.sh b/ci/travis/install-bazel.sh
index c9614f772..5b6d95729 100755
--- a/ci/travis/install-bazel.sh
+++ b/ci/travis/install-bazel.sh
@@ -16,7 +16,7 @@ else
exit 1
fi
-URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh"
+URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh"
wget -O install.sh $URL
chmod +x install.sh
./install.sh --user
diff --git a/java/BUILD.bazel b/java/BUILD.bazel
index f3ae6f063..80ccabccf 100644
--- a/java/BUILD.bazel
+++ b/java/BUILD.bazel
@@ -94,11 +94,14 @@ define_java_module(
":org_ray_ray_api",
":org_ray_ray_runtime",
"@plasma//:org_apache_arrow_arrow_plasma",
+ "@maven//:com_google_guava_guava",
+ "@maven//:com_sun_xml_bind_jaxb_core",
+ "@maven//:com_sun_xml_bind_jaxb_impl",
+ "@maven//:commons_io_commons_io",
+ "@maven//:javax_xml_bind_jaxb_api",
"@maven//:org_apache_commons_commons_lang3",
"@maven//:org_slf4j_slf4j_api",
"@maven//:org_testng_testng",
- "@maven//:com_google_guava_guava",
- "@maven//:commons_io_commons_io",
],
)
diff --git a/java/test/pom.xml b/java/test/pom.xml
index 10f7ea4b3..6a3a31d20 100644
--- a/java/test/pom.xml
+++ b/java/test/pom.xml
@@ -32,11 +32,26 @@
guava
27.0.1-jre
+
+ com.sun.xml.bind
+ jaxb-core
+ 2.3.0
+
+
+ com.sun.xml.bind
+ jaxb-impl
+ 2.3.0
+
commons-io
commons-io
2.5
+
+ javax.xml.bind
+ jaxb-api
+ 2.3.0
+
org.apache.commons
commons-lang3
diff --git a/python/ray/tests/perf_integration_tests/test_perf_integration.py b/python/ray/tests/perf_integration_tests/test_perf_integration.py
index 2ce2a305a..ff34fe412 100644
--- a/python/ray/tests/perf_integration_tests/test_perf_integration.py
+++ b/python/ray/tests/perf_integration_tests/test_perf_integration.py
@@ -6,6 +6,7 @@ import numpy as np
import pytest
import ray
+from ray.tests.conftest import _ray_start_cluster
num_tasks_submitted = [10**n for n in range(0, 6)]
num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted]
@@ -41,3 +42,25 @@ def test_task_submission(benchmark, num_tasks):
warmup()
benchmark(benchmark_task_submission, num_tasks)
ray.shutdown()
+
+
+def benchmark_task_forward(f, num_tasks):
+ ray.get([f.remote() for _ in range(num_tasks)])
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize(
+ "num_tasks", [10**3, 10**4],
+ ids=[str(num) + "_tasks" for num in [10**3, 10**4]])
+def test_task_forward(benchmark, num_tasks):
+ with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster:
+ cluster.add_node(resources={"my_resource": 100})
+ ray.init(redis_address=cluster.redis_address)
+
+ @ray.remote(resources={"my_resource": 0.001})
+ def f():
+ return 1
+
+ # Warm up: run 100 tasks before the timed benchmark call below.
+ ray.get([f.remote() for _ in range(100)])
+ benchmark(benchmark_task_forward, f, num_tasks)
diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto
new file mode 100644
index 000000000..8a82da1c7
--- /dev/null
+++ b/src/ray/protobuf/node_manager.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package ray.rpc;
+
+message ForwardTaskRequest {
+ // The ID of the task to be forwarded.
+ bytes task_id = 1;
+ // The tasks in the uncommitted lineage of the forwarded task. This
+ // should include task_id.
+ // TODO(hchen): Currently, `uncommitted_tasks` are represented as
+ // flatbuffers-serialized bytes. This is because the flatbuffers-defined Task data
+ // structure is being used in many places. We should move Task and all related data
+ // structures to protobuf.
+ repeated bytes uncommitted_tasks = 2;
+}
+
+message ForwardTaskReply {
+}
+
+// Service for inter-node-manager communication.
+service NodeManagerService {
+ // Forward a task and its uncommitted lineage to the remote node manager.
+ rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply);
+}
diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc
index 671a7a798..a0bde1ff0 100644
--- a/src/ray/raylet/node_manager.cc
+++ b/src/ray/raylet/node_manager.cc
@@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
lineage_cache_(gcs_client_->client_table().GetLocalClientId(),
gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(),
config.max_lineage_size),
- remote_clients_(),
- remote_server_connections_(),
- actor_registry_() {
+ actor_registry_(),
+ node_manager_server_(config.node_manager_port, io_service, *this),
+ client_call_manager_(io_service) {
RAY_CHECK(heartbeat_period_.count() > 0);
// Initialize the resource map with own cluster resource configuration.
ClientID local_client_id = gcs_client_->client_table().GetLocalClientId();
@@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
[this](const ObjectID &object_id) { HandleObjectMissing(object_id); }));
RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str()));
+ // Run the node manager rpc server.
+ node_manager_server_.Run();
}
ray::Status NodeManager::RegisterGcs() {
@@ -366,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) {
return;
}
- // TODO(atumanov): make remote client lookup O(1)
- if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) ==
- remote_clients_.end()) {
- remote_clients_.push_back(client_id);
- } else {
- // NodeManager connection to this client was already established.
- RAY_LOG(DEBUG) << "Received a new client connection that already exists: "
+ auto entry = remote_node_manager_clients_.find(client_id);
+ if (entry != remote_node_manager_clients_.end()) {
+ RAY_LOG(DEBUG) << "Received notification of a new client that already exists: "
<< client_id;
return;
}
- // Establish a new NodeManager connection to this GCS client.
- auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address,
- client_data.node_manager_port);
- if (!status.ok()) {
- // This is not a fatal error for raylet, but it should not happen.
- // We need to broadcase this message.
- std::string type = "raylet_connection_error";
- std::ostringstream error_message;
- error_message << "Failed to connect to ray node " << client_id
- << " with status: " << status.ToString()
- << ". This may be since the node was recently removed.";
- // We use the nil DriverID to broadcast the message to all drivers.
- RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
- DriverID::Nil(), type, error_message.str(), current_time_ms()));
- return;
- }
+ // Initialize a rpc client to the new node manager.
+ std::unique_ptr client(
+ new rpc::NodeManagerClient(client_data.node_manager_address,
+ client_data.node_manager_port, client_call_manager_));
+ remote_node_manager_clients_.emplace(client_id, std::move(client));
ResourceSet resources_total(client_data.resources_total_label,
client_data.resources_total_capacity);
cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total));
}
-ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id,
- const std::string &client_address,
- int32_t client_port) {
- // Establish a new NodeManager connection to this GCS client.
- RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at "
- << client_address << ":" << client_port;
-
- boost::asio::ip::tcp::socket socket(io_service_);
- RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port));
-
- // The client is connected, now send a connect message to remote node manager.
- auto server_conn = TcpServerConnection::Create(std::move(socket));
-
- // Prepare client connection info buffer
- flatbuffers::FlatBufferBuilder fbb;
- auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_));
- fbb.Finish(message);
- // Send synchronously.
- // TODO(swang): Make this a WriteMessageAsync.
- RAY_RETURN_NOT_OK(server_conn->WriteMessage(
- static_cast(protocol::MessageType::ConnectClient), fbb.GetSize(),
- fbb.GetBufferPointer()));
-
- remote_server_connections_.emplace(client_id, std::move(server_conn));
- return ray::Status::OK();
-}
-
void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// TODO(swang): If we receive a notification for our own death, clean up and
// exit immediately.
@@ -440,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) {
// check that it is actually removed, or log a warning otherwise, but that may
// not be necessary.
- // Remove the client from the list of remote clients.
- std::remove(remote_clients_.begin(), remote_clients_.end(), client_id);
-
// Remove the client from the resource map.
cluster_resource_map_.erase(client_id);
- // Remove the remote server connection.
- const auto connection_entry = remote_server_connections_.find(client_id);
- if (connection_entry != remote_server_connections_.end()) {
- connection_entry->second->Close();
- remote_server_connections_.erase(connection_entry);
+ // Remove the node manager client.
+ const auto client_entry = remote_node_manager_clients_.find(client_id);
+ if (client_entry != remote_node_manager_clients_.end()) {
+ remote_node_manager_clients_.erase(client_entry);
} else {
RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client "
<< client_id << ".";
@@ -1241,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client
node_manager_client.ProcessMessages();
}
-void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
- int64_t message_type,
- const uint8_t *message_data) {
- const auto message_type_value = static_cast(message_type);
- RAY_LOG(DEBUG) << "[NodeManager] Message "
- << protocol::EnumNameMessageType(message_type_value) << "("
- << message_type << ") from node manager";
- switch (message_type_value) {
- case protocol::MessageType::ConnectClient: {
- auto message = flatbuffers::GetRoot(message_data);
- auto client_id = from_flatbuf(*message->client_id());
- node_manager_client.SetClientID(client_id);
- } break;
- case protocol::MessageType::ForwardTaskRequest: {
- auto message = flatbuffers::GetRoot(message_data);
- TaskID task_id = from_flatbuf(*message->task_id());
-
- Lineage uncommitted_lineage(*message);
- const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
- RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
- << " on node " << gcs_client_->client_table().GetLocalClientId()
- << " spillback=" << task.GetTaskExecutionSpec().NumForwards();
- SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
- } break;
- case protocol::MessageType::DisconnectClient: {
- // TODO(rkn): We need to do some cleanup here.
- RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. "
- << "We need to do some cleanup here.";
- // Do not process any more messages from this node manager.
- return;
- } break;
- default:
- RAY_LOG(FATAL) << "Received unexpected message type " << message_type;
+void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request,
+ rpc::ForwardTaskReply *reply,
+ rpc::RequestDoneCallback done_callback) {
+ // Get the forwarded task and its uncommitted lineage from the request.
+ TaskID task_id = TaskID::FromBinary(request.task_id());
+ Lineage uncommitted_lineage;
+ for (int i = 0; i < request.uncommitted_tasks_size(); i++) {
+ const std::string &task_message = request.uncommitted_tasks(i);
+ const Task task(*flatbuffers::GetRoot(
+ reinterpret_cast(task_message.data())));
+ RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED));
}
- node_manager_client.ProcessMessages();
+ const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData();
+ RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId()
+ << " on node " << gcs_client_->client_table().GetLocalClientId()
+ << " spillback=" << task.GetTaskExecutionSpec().NumForwards();
+ SubmitTask(task, uncommitted_lineage, /* forwarded = */ true);
+ done_callback(Status::OK());
}
void NodeManager::ProcessSetResourceRequest(
@@ -2253,6 +2192,16 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task,
void NodeManager::ForwardTask(
const Task &task, const ClientID &node_id,
const std::function &on_error) {
+ // Lookup node manager client for this node_id and use it to send the request.
+ auto client_entry = remote_node_manager_clients_.find(node_id);
+ if (client_entry == remote_node_manager_clients_.end()) {
+ // TODO(atumanov): caller must handle failure to ensure tasks are not lost.
+ RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id;
+ on_error(ray::Status::IOError("Node manager client not found"), task);
+ return;
+ }
+ auto &client = client_entry->second;
+
const auto &spec = task.GetTaskSpecification();
auto task_id = spec.TaskId();
@@ -2272,68 +2221,61 @@ void NodeManager::ForwardTask(
// Increment forward count for the forwarded task.
lineage_cache_entry_task.IncrementNumForwards();
- flatbuffers::FlatBufferBuilder fbb;
- auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id);
- fbb.Finish(request);
-
RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from "
<< gcs_client_->client_table().GetLocalClientId() << " to " << node_id
<< " spillback="
<< lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards();
- // Lookup remote server connection for this node_id and use it to send the request.
- auto it = remote_server_connections_.find(node_id);
- if (it == remote_server_connections_.end()) {
- // TODO(atumanov): caller must handle failure to ensure tasks are not lost.
- RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id;
- on_error(ray::Status::IOError("NodeManager connection not found"), task);
- return;
+ // Prepare the request message.
+ rpc::ForwardTaskRequest request;
+ request.set_task_id(task_id.Binary());
+ for (auto &entry : uncommitted_lineage.GetEntries()) {
+ request.add_uncommitted_tasks(entry.second.TaskData().Serialize());
}
- auto &server_conn = it->second;
+
// Move the FORWARDING task to the SWAP queue so that we remember that we
// have it queued locally. Once the ForwardTaskRequest has been sent, the
// task will get re-queued, depending on whether the message succeeded or
// not.
local_queues_.QueueTasks({task}, TaskState::SWAP);
- server_conn->WriteMessageAsync(
- static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(),
- fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) {
- // Remove the FORWARDING task from the SWAP queue.
- TaskState state;
- const auto task = local_queues_.RemoveTask(task_id, &state);
- RAY_CHECK(state == TaskState::SWAP);
+ client->ForwardTask(request, [this, on_error, task_id, node_id](
+ Status status, const rpc::ForwardTaskReply &reply) {
+ // Remove the FORWARDING task from the SWAP queue.
+ TaskState state;
+ const auto task = local_queues_.RemoveTask(task_id, &state);
+ RAY_CHECK(state == TaskState::SWAP);
- if (status.ok()) {
- const auto &spec = task.GetTaskSpecification();
- // Mark as forwarded so that the task and its lineage are not
- // re-forwarded in the future to the receiving node.
- lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
+ if (status.ok()) {
+ const auto &spec = task.GetTaskSpecification();
+ // Mark as forwarded so that the task and its lineage are not
+ // re-forwarded in the future to the receiving node.
+ lineage_cache_.MarkTaskAsForwarded(task_id, node_id);
- // Notify the task dependency manager that we are no longer responsible
- // for executing this task.
- task_dependency_manager_.TaskCanceled(task_id);
- // Preemptively push any local arguments to the receiving node. For now, we
- // only do this with actor tasks, since actor tasks must be executed by a
- // specific process and therefore have affinity to the receiving node.
- if (spec.IsActorTask()) {
- // Iterate through the object's arguments. NOTE(swang): We do not include
- // the execution dependencies here since those cannot be transferred
- // between nodes.
- for (int i = 0; i < spec.NumArgs(); ++i) {
- int count = spec.ArgIdCount(i);
- for (int j = 0; j < count; j++) {
- ObjectID argument_id = spec.ArgId(i, j);
- // If the argument is local, then push it to the receiving node.
- if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
- object_manager_.Push(argument_id, node_id);
- }
- }
+ // Notify the task dependency manager that we are no longer responsible
+ // for executing this task.
+ task_dependency_manager_.TaskCanceled(task_id);
+ // Preemptively push any local arguments to the receiving node. For now, we
+ // only do this with actor tasks, since actor tasks must be executed by a
+ // specific process and therefore have affinity to the receiving node.
+ if (spec.IsActorTask()) {
+ // Iterate through the object's arguments. NOTE(swang): We do not include
+ // the execution dependencies here since those cannot be transferred
+ // between nodes.
+ for (int i = 0; i < spec.NumArgs(); ++i) {
+ int count = spec.ArgIdCount(i);
+ for (int j = 0; j < count; j++) {
+ ObjectID argument_id = spec.ArgId(i, j);
+ // If the argument is local, then push it to the receiving node.
+ if (task_dependency_manager_.CheckObjectLocal(argument_id)) {
+ object_manager_.Push(argument_id, node_id);
}
}
- } else {
- on_error(status, task);
}
- });
+ }
+ } else {
+ on_error(status, task);
+ }
+ });
}
void NodeManager::DumpDebugState() const {
@@ -2368,10 +2310,11 @@ std::string NodeManager::DebugString() const {
result << "\n- num dead actors: " << statistical_data.dead_actors;
result << "\n- max num handles: " << statistical_data.max_num_handles;
- result << "\nRemoteConnections:";
- for (auto &pair : remote_server_connections_) {
- result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString();
+ result << "\nRemote node manager clients: ";
+ for (const auto &entry : remote_node_manager_clients_) {
+ result << "\n" << entry.first;
}
+
result << "\nDebugString() time ms: " << (current_time_ms() - now_ms);
return result.str();
}
diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h
index 3f7e4d7da..616133583 100644
--- a/src/ray/raylet/node_manager.h
+++ b/src/ray/raylet/node_manager.h
@@ -4,6 +4,9 @@
#include
// clang-format off
+#include "ray/rpc/client_call.h"
+#include "ray/rpc/node_manager_server.h"
+#include "ray/rpc/node_manager_client.h"
#include "ray/raylet/task.h"
#include "ray/object_manager/object_manager.h"
#include "ray/common/client_connection.h"
@@ -52,7 +55,7 @@ struct NodeManagerConfig {
std::string session_dir;
};
-class NodeManager {
+class NodeManager : public rpc::NodeManagerServiceHandler {
public:
/// Create a node manager.
///
@@ -86,15 +89,6 @@ class NodeManager {
/// \return Void.
void ProcessNewNodeManager(TcpClientConnection &node_manager_client);
- /// Handle a message from a remote node manager.
- ///
- /// \param node_manager_client The connection to the remote node manager.
- /// \param message_type The type of the message.
- /// \param message The message contents.
- /// \return Void.
- void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client,
- int64_t message_type, const uint8_t *message);
-
/// Subscribe to the relevant GCS tables and set up handlers.
///
/// \return Status indicating whether this was done successfully or not.
@@ -108,6 +102,9 @@ class NodeManager {
/// Record metrics.
void RecordMetrics() const;
+ /// Get the port of the node manager rpc server.
+ int GetServerPort() const { return node_manager_server_.GetPort(); }
+
private:
/// Methods for handling clients.
@@ -450,15 +447,10 @@ class NodeManager {
void HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
bool intentional_disconnect);
- /// connect to a remote node manager.
- ///
- /// \param client_id The client ID for the remote node manager.
- /// \param client_address The IP address for the remote node manager.
- /// \param client_port The listening port for the remote node manager.
- /// \return True if the connect succeeds.
- ray::Status ConnectRemoteNodeManager(const ClientID &client_id,
- const std::string &client_address,
- int32_t client_port);
+ /// Handle a `ForwardTask` request.
+ void HandleForwardTask(const rpc::ForwardTaskRequest &request,
+ rpc::ForwardTaskReply *reply,
+ rpc::RequestDoneCallback done_callback) override;
// GCS client ID for this node.
ClientID client_id_;
@@ -505,9 +497,6 @@ class NodeManager {
TaskDependencyManager task_dependency_manager_;
/// The lineage cache for the GCS object and task tables.
LineageCache lineage_cache_;
- std::vector remote_clients_;
- std::unordered_map>
- remote_server_connections_;
/// A mapping from actor ID to registration information about that actor
/// (including which node manager owns it).
std::unordered_map actor_registry_;
@@ -515,6 +504,16 @@ class NodeManager {
/// This map stores actor ID to the ID of the checkpoint that will be used to
/// restore the actor.
std::unordered_map checkpoint_id_to_restore_;
+
+ /// The RPC server.
+ rpc::NodeManagerServer node_manager_server_;
+
+ /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s.
+ rpc::ClientCallManager client_call_manager_;
+
+ /// Map from node ids to clients of the remote node managers.
+ std::unordered_map>
+ remote_node_manager_clients_;
};
} // namespace raylet
diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc
index 80630d372..473e6c263 100644
--- a/src/ray/raylet/raylet.cc
+++ b/src/ray/raylet/raylet.cc
@@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_
main_service,
boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(),
object_manager_config.object_manager_port)),
- object_manager_socket_(main_service),
- node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint(
- boost::asio::ip::tcp::v4(),
- node_manager_config.node_manager_port)),
- node_manager_socket_(main_service) {
+ object_manager_socket_(main_service) {
// Start listening for clients.
DoAccept();
DoAcceptObjectManager();
- DoAcceptNodeManager();
RAY_CHECK_OK(RegisterGcs(
node_ip_address, socket_name_, object_manager_config.store_socket_name,
@@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
client_info.raylet_socket_name = raylet_socket_name;
client_info.object_store_socket_name = object_store_socket_name;
client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port();
- client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port();
+ client_info.node_manager_port = node_manager_.GetServerPort();
// Add resource information.
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) {
client_info.resources_total_label.push_back(resource_pair.first);
@@ -120,33 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address,
return Status::OK();
}
-void Raylet::DoAcceptNodeManager() {
- node_manager_acceptor_.async_accept(node_manager_socket_,
- boost::bind(&Raylet::HandleAcceptNodeManager, this,
- boost::asio::placeholders::error));
-}
-
-void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) {
- if (!error) {
- ClientHandler client_handler =
- [this](TcpClientConnection &client) {
- node_manager_.ProcessNewNodeManager(client);
- };
- MessageHandler message_handler =
- [this](std::shared_ptr client, int64_t message_type,
- const uint8_t *message) {
- node_manager_.ProcessNodeManagerMessage(*client, message_type, message);
- };
- // Accept a new TCP client and dispatch it to the node manager.
- auto new_connection = TcpClientConnection::Create(
- client_handler, message_handler, std::move(node_manager_socket_), "node manager",
- node_manager_message_enum,
- static_cast(protocol::MessageType::DisconnectClient));
- }
- // We're ready to accept another client.
- DoAcceptNodeManager();
-}
-
void Raylet::DoAcceptObjectManager() {
object_manager_acceptor_.async_accept(
object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this,
diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h
index 84274ea6e..26fe74b2b 100644
--- a/src/ray/raylet/raylet.h
+++ b/src/ray/raylet/raylet.h
@@ -63,8 +63,6 @@ class Raylet {
void DoAcceptObjectManager();
/// Handle an accepted tcp client connection.
void HandleAcceptObjectManager(const boost::system::error_code &error);
- void DoAcceptNodeManager();
- void HandleAcceptNodeManager(const boost::system::error_code &error);
friend class TestObjectManagerIntegration;
@@ -88,10 +86,6 @@ class Raylet {
boost::asio::ip::tcp::acceptor object_manager_acceptor_;
/// The socket to listen on for new object manager tcp clients.
boost::asio::ip::tcp::socket object_manager_socket_;
- /// An acceptor for new tcp clients.
- boost::asio::ip::tcp::acceptor node_manager_acceptor_;
- /// The socket to listen on for new tcp clients.
- boost::asio::ip::tcp::socket node_manager_socket_;
};
} // namespace raylet
diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc
index 5d6a02186..9d8036411 100644
--- a/src/ray/raylet/task.cc
+++ b/src/ray/raylet/task.cc
@@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) {
ComputeDependencies();
}
+const std::string Task::Serialize() const {
+ flatbuffers::FlatBufferBuilder fbb;
+ fbb.Finish(ToFlatbuffer(fbb));
+ return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
+}
+
std::string SerializeTaskAsString(const std::vector *dependencies,
const TaskSpecification *task_spec) {
- flatbuffers::FlatBufferBuilder fbb;
std::vector execution_dependencies(*dependencies);
TaskExecutionSpecification execution_spec(std::move(execution_dependencies));
Task task(execution_spec, *task_spec);
- fbb.Finish(task.ToFlatbuffer(fbb));
- return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize());
+ return task.Serialize();
}
} // namespace raylet
diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h
index b942e2bf2..10cdfe511 100644
--- a/src/ray/raylet/task.h
+++ b/src/ray/raylet/task.h
@@ -84,6 +84,9 @@ class Task {
/// \param task Task structure with updated dynamic information.
void CopyTaskExecutionSpec(const Task &task);
+ /// Serialize this task as a string.
+ const std::string Serialize() const;
+
private:
void ComputeDependencies();
diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h
new file mode 100644
index 000000000..725652cb5
--- /dev/null
+++ b/src/ray/rpc/client_call.h
@@ -0,0 +1,169 @@
+#ifndef RAY_RPC_CLIENT_CALL_H
+#define RAY_RPC_CLIENT_CALL_H
+
+#include
+#include
+
+#include "ray/common/status.h"
+#include "ray/rpc/util.h"
+
+namespace ray {
+namespace rpc {
+
+/// Represents an outgoing gRPC request.
+///
+/// The lifecycle of a `ClientCall` is as follows.
+///
+/// When a client submits a new gRPC request, a new `ClientCall` object is created
+/// by `ClientCallManager::CreateCall`. The object's address is then used as the tag
+/// on the `CompletionQueue`.
+///
+/// When the reply is received, `ClientCallManager` recovers the address of this object
+/// from the `CompletionQueue` tag, calls `OnReplyReceived`, and then deletes this
+/// object through a `ClientCall *` -- hence the virtual destructor below.
+///
+/// NOTE(hchen): Unlike `ClientCallImpl`, this abstract interface is not a template,
+/// so its users (e.g., `ClientCallManager`) do not have to be templates either.
+class ClientCall {
+ public:
+ virtual ~ClientCall() = default;
+
+ /// Called by `ClientCallManager` when the reply of this request is received.
+ virtual void OnReplyReceived() = 0;
+};
+
+class ClientCallManager;
+
+/// Represents the client callback function of a particular RPC method.
+///
+/// \tparam Reply Type of the reply message.
+template
+using ClientCallback = std::function;
+
+/// Implementation of the `ClientCall`. It represents a `ClientCall` for a particular
+/// RPC method.
+///
+/// \tparam Reply Type of the Reply message.
+template
+class ClientCallImpl : public ClientCall {
+ public:
+ void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); }
+
+ private:
+ /// Constructor. Private: only the friend class `ClientCallManager` creates calls.
+ ///
+ /// \param[in] callback The callback function to handle the reply.
+ ClientCallImpl(const ClientCallback &callback) : callback_(callback) {}
+
+ /// The reply message; its address is handed to gRPC when the call is started.
+ Reply reply_;
+
+ /// The callback function to handle the reply.
+ ClientCallback callback_;
+
+ /// The response reader.
+ std::unique_ptr> response_reader_;
+
+ /// gRPC status of this request; converted to a ray status for the callback.
+ grpc::Status status_;
+
+ /// Context for the client. It could be used to convey extra information to
+ /// the server and/or tweak certain RPC behaviors.
+ grpc::ClientContext context_;
+
+ friend class ClientCallManager;
+};
+
+/// Represents the generic signature of a `FooService::Stub::PrepareAsyncBar`
+/// function, where `Foo` is the service name and `Bar` is the RPC method name.
+///
+/// \tparam GrpcService Type of the gRPC-generated service class.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template
+using PrepareAsyncFunction = std::unique_ptr> (
+ GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request,
+ grpc::CompletionQueue *cq);
+
+/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of
+/// `ClientCall` objects.
+///
+/// It starts a thread that keeps polling events from the `CompletionQueue`, and posts
+/// the callback function to the main event loop when a reply is received.
+///
+/// Multiple clients can share one `ClientCallManager`.
+class ClientCallManager {
+ public:
+ /// Constructor.
+ ///
+ /// \param[in] main_service The main event loop, to which the callback functions will be
+ /// posted.
+ ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) {
+ // Start the polling thread. It is detached, so the destructor cannot join it.
+ std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this);
+ polling_thread.detach();
+ }
+
+ ~ClientCallManager() { cq_.Shutdown(); }
+
+ /// Create a new `ClientCall` and send the request.
+ ///
+ /// \param[in] stub The gRPC-generated stub.
+ /// \param[in] prepare_async_function Pointer to the gRPC-generated
+ /// `FooService::Stub::PrepareAsyncBar` function.
+ /// \param[in] request The request message.
+ /// \param[in] callback The callback function that handles the reply.
+ ///
+ /// \tparam GrpcService Type of the gRPC-generated service class.
+ /// \tparam Request Type of the request message.
+ /// \tparam Reply Type of the reply message.
+ template
+ ClientCall *CreateCall(
+ typename GrpcService::Stub &stub,
+ const PrepareAsyncFunction prepare_async_function,
+ const Request &request, const ClientCallback &callback) {
+ // Create a new `ClientCall` object. It is deleted after its completion event is
+ // polled in `ClientCallManager::PollEventsFromCompletionQueue`, not by the caller.
+ auto call = new ClientCallImpl(callback);
+ // Send the request; the call's address is the completion tag passed to `Finish`.
+ call->response_reader_ =
+ (stub.*prepare_async_function)(&call->context_, request, &cq_);
+ call->response_reader_->StartCall();
+ call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call);
+ return call;
+ }
+
+ private:
+ /// This function runs in a background thread. It keeps polling events from the
+ /// `CompletionQueue`, and dispatches the event to the callbacks via the `ClientCall`
+ /// objects.
+ void PollEventsFromCompletionQueue() {
+ void *got_tag;
+ bool ok = false;
+ // Keep reading events from the `CompletionQueue` until it is shut down.
+ while (cq_.Next(&got_tag, &ok)) {
+ ClientCall *call = reinterpret_cast(got_tag);
+ if (ok) {
+ // Post the callback to the main event loop.
+ main_service_.post([call]() {
+ call->OnReplyReceived();
+ // The call is finished, we can delete the `ClientCall` object now.
+ delete call;
+ });
+ } else {
+ delete call;
+ }
+ }
+ }
+
+ /// The main event loop, to which the callback functions will be posted.
+ boost::asio::io_service &main_service_;
+
+ /// The gRPC `CompletionQueue` object used to poll events.
+ grpc::CompletionQueue cq_;
+};
+
+} // namespace rpc
+} // namespace ray
+
+#endif
diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc
new file mode 100644
index 000000000..feb788da7
--- /dev/null
+++ b/src/ray/rpc/grpc_server.cc
@@ -0,0 +1,70 @@
+#include "ray/rpc/grpc_server.h"
+
+namespace ray {
+namespace rpc {
+
+void GrpcServer::Run() {
+ std::string server_address("0.0.0.0:" + std::to_string(port_));
+
+ grpc::ServerBuilder builder;
+ // TODO(hchen): Add options for authentication (insecure credentials for now).
+ builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_);
+ // Allow subclasses to register concrete services.
+ RegisterServices(builder);
+ // Get hold of the completion queue used for the asynchronous communication
+ // with the gRPC runtime.
+ cq_ = builder.AddCompletionQueue();
+ // Build and start the server.
+ server_ = builder.BuildAndStart();
+ RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << ".";
+
+ // Allow subclasses to initialize the server call factories.
+ InitServerCallFactories(&server_call_factories_and_concurrencies_);
+ for (auto &entry : server_call_factories_and_concurrencies_) {
+ for (int i = 0; i < entry.second; i++) {
+ // Each created call asks gRPC to accept one incoming request of this kind.
+ entry.first->CreateCall();
+ }
+ }
+ // Start a detached thread that polls events from the completion queue.
+ std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this);
+ polling_thread.detach();
+}
+
+void GrpcServer::PollEventsFromCompletionQueue() {
+ void *tag;
+ bool ok;
+ // Keep reading events from the `CompletionQueue` until it is shut down.
+ while (cq_->Next(&tag, &ok)) {
+ ServerCall *server_call = static_cast(tag);
+ // `ok == false` indicates that the server has been shut down.
+ // We should delete the call object in this case.
+ bool delete_call = !ok;
+ if (ok) {
+ switch (server_call->GetState()) {
+ case ServerCallState::PENDING:
+ // We've received a new incoming request. This call object now tracks that
+ // request, so we need to create another call to accept the next
+ // incoming request.
+ server_call->GetFactory().CreateCall();
+ server_call->SetState(ServerCallState::PROCESSING);
+ main_service_.post([server_call] { server_call->HandleRequest(); });
+ break;
+ case ServerCallState::SENDING_REPLY:
+ // The reply has been sent, so this call can be deleted now.
+ // This event is triggered by `ServerCallImpl::SendReply`.
+ delete_call = true;
+ break;
+ default:
+ RAY_LOG(FATAL) << "Shouldn't reach here.";
+ break;
+ }
+ }
+ if (delete_call) {
+ delete server_call;
+ }
+ }
+}
+
+} // namespace rpc
+} // namespace ray
diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h
new file mode 100644
index 000000000..4953f4706
--- /dev/null
+++ b/src/ray/rpc/grpc_server.h
@@ -0,0 +1,92 @@
+#ifndef RAY_RPC_GRPC_SERVER_H
+#define RAY_RPC_GRPC_SERVER_H
+
+#include
+
+#include
+#include
+
+#include "ray/common/status.h"
+#include "ray/rpc/server_call.h"
+
+namespace ray {
+namespace rpc {
+
+/// Base class that represents an abstract gRPC server.
+///
+/// A `GrpcServer` listens on a specific port. It owns
+/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC,
+/// 2) and a thread that polls events from the `ServerCompletionQueue`.
+///
+/// Subclasses can register one or multiple services to a `GrpcServer`, see
+/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide
+/// which kinds of requests this server should accept.
+class GrpcServer {
+ public:
+ /// Constructor. Members are initialized in their declaration order.
+ ///
+ /// \param[in] name Name of this server, used for logging and debugging purposes.
+ /// \param[in] port The port to bind this server to. If it's 0, a random available port
+ /// will be chosen.
+ /// \param[in] main_service The main event loop, to which service handler functions
+ /// will be posted.
+ GrpcServer(const std::string &name, const uint32_t port,
+ boost::asio::io_service &main_service)
+ : main_service_(main_service), name_(name), port_(port) {}
+
+ /// Destruct this gRPC server; a server or queue that was never created is skipped.
+ virtual ~GrpcServer() {
+ if (server_) server_->Shutdown();
+ if (cq_) cq_->Shutdown();
+ }
+
+ /// Initialize and run this server.
+ void Run();
+
+ /// Get the port of this gRPC server.
+ int GetPort() const { return port_; }
+
+ protected:
+ /// Subclasses should implement this method and register one or multiple gRPC services
+ /// to the given `ServerBuilder`.
+ ///
+ /// \param[in] builder The `ServerBuilder` instance to register services to.
+ virtual void RegisterServices(grpc::ServerBuilder &builder) = 0;
+
+ /// Subclasses should implement this method to initialize the `ServerCallFactory`
+ /// instances, as well as specify maximum number of concurrent requests that gRPC
+ /// server can "accept" (not "handle"). Each factory will be used to create
+ /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and
+ /// handle an incoming request.
+ ///
+ /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects,
+ /// and the maximum number of concurrent requests that gRPC server can accept.
+ virtual void InitServerCallFactories(
+ std::vector, int>>
+ *server_call_factories_and_concurrencies) = 0;
+
+ /// This function runs in a background thread. It keeps polling events from the
+ /// `ServerCompletionQueue`, and dispatches the event to the `ServiceHandler` instances
+ /// via the `ServerCall` objects.
+ void PollEventsFromCompletionQueue();
+
+ /// The main event loop, to which the service handler functions will be posted.
+ boost::asio::io_service &main_service_;
+ /// Name of this server, used for logging and debugging purposes.
+ const std::string name_;
+ /// Port of this server.
+ int port_;
+ /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that
+ /// gRPC server can accept.
+ std::vector, int>>
+ server_call_factories_and_concurrencies_;
+ /// The `ServerCompletionQueue` object used for polling events.
+ std::unique_ptr cq_;
+ /// The `Server` object.
+ std::unique_ptr server_;
+};
+
+} // namespace rpc
+} // namespace ray
+
+#endif
diff --git a/src/ray/rpc/node_manager_client.h b/src/ray/rpc/node_manager_client.h
new file mode 100644
index 000000000..005c75db4
--- /dev/null
+++ b/src/ray/rpc/node_manager_client.h
@@ -0,0 +1,56 @@
+#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H
+#define RAY_RPC_NODE_MANAGER_CLIENT_H
+
+#include
+
+#include
+
+#include "ray/common/status.h"
+#include "ray/rpc/client_call.h"
+#include "ray/util/logging.h"
+#include "src/ray/protobuf/node_manager.grpc.pb.h"
+#include "src/ray/protobuf/node_manager.pb.h"
+
+namespace ray {
+namespace rpc {
+
+/// Client used for communicating with a remote node manager server.
+class NodeManagerClient {
+ public:
+ /// Constructor. Creates an insecure channel to `address:port` and a stub on it.
+ ///
+ /// \param[in] address Address of the node manager server.
+ /// \param[in] port Port of the node manager server.
+ /// \param[in] client_call_manager The `ClientCallManager` used for managing requests.
+ NodeManagerClient(const std::string &address, const int port,
+ ClientCallManager &client_call_manager)
+ : client_call_manager_(client_call_manager) {
+ std::shared_ptr channel = grpc::CreateChannel(
+ address + ":" + std::to_string(port), grpc::InsecureChannelCredentials());
+ stub_ = NodeManagerService::NewStub(channel);
+ };
+
+ /// Forward a task and its uncommitted lineage.
+ ///
+ /// \param[in] request The request message.
+ /// \param[in] callback The callback function that handles the reply.
+ void ForwardTask(const ForwardTaskRequest &request,
+ const ClientCallback &callback) {
+ client_call_manager_
+ .CreateCall(
+ *stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request,
+ callback);
+ }
+
+ private:
+ /// The gRPC-generated stub.
+ std::unique_ptr stub_;
+
+ /// The `ClientCallManager` that sends this client's requests and polls the replies.
+ ClientCallManager &client_call_manager_;
+};
+
+} // namespace rpc
+} // namespace ray
+
+#endif // RAY_RPC_NODE_MANAGER_CLIENT_H
diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h
new file mode 100644
index 000000000..afaea299e
--- /dev/null
+++ b/src/ray/rpc/node_manager_server.h
@@ -0,0 +1,71 @@
+#ifndef RAY_RPC_NODE_MANAGER_SERVER_H
+#define RAY_RPC_NODE_MANAGER_SERVER_H
+
+#include "ray/rpc/grpc_server.h"
+#include "ray/rpc/server_call.h"
+
+#include "src/ray/protobuf/node_manager.grpc.pb.h"
+#include "src/ray/protobuf/node_manager.pb.h"
+
+namespace ray {
+namespace rpc {
+
+/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`.
+class NodeManagerServiceHandler {
+ public:
+ /// Handle a `ForwardTask` request.
+ /// The implementation can handle this request asynchronously. When handling is done,
+ /// the `done_callback` should be called.
+ ///
+ /// \param[in] request The request message.
+ /// \param[out] reply The reply message.
+ /// \param[in] done_callback The callback to be called when the request is done.
+ virtual void HandleForwardTask(const ForwardTaskRequest &request,
+ ForwardTaskReply *reply,
+ RequestDoneCallback done_callback) = 0;
+};
+
+/// The `GrpcServer` for `NodeManagerService`.
+class NodeManagerServer : public GrpcServer {
+ public:
+ /// Constructor.
+ ///
+ /// \param[in] port See super class.
+ /// \param[in] main_service See super class.
+ /// \param[in] service_handler The service handler that actually handles the requests.
+ NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service,
+ NodeManagerServiceHandler &service_handler)
+ : GrpcServer("NodeManager", port, main_service),
+ service_handler_(service_handler){};
+
+ void RegisterServices(grpc::ServerBuilder &builder) override {
+ // Register `NodeManagerService`.
+ builder.RegisterService(&service_);
+ }
+
+ void InitServerCallFactories(
+ std::vector, int>>
+ *server_call_factories_and_concurrencies) override {
+ // Initialize the factory for `ForwardTask` requests.
+ std::unique_ptr forward_task_call_factory(
+ new ServerCallFactoryImpl(
+ service_, &NodeManagerService::AsyncService::RequestForwardTask,
+ service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_));
+
+ // Set `ForwardTask`'s accept concurrency to 100.
+ server_call_factories_and_concurrencies->emplace_back(
+ std::move(forward_task_call_factory), 100);
+ }
+
+ private:
+ /// The gRPC async service object.
+ NodeManagerService::AsyncService service_;
+ /// The service handler that actually handles the requests.
+ NodeManagerServiceHandler &service_handler_;
+};
+
+} // namespace rpc
+} // namespace ray
+
+#endif
diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h
new file mode 100644
index 000000000..e06278260
--- /dev/null
+++ b/src/ray/rpc/server_call.h
@@ -0,0 +1,233 @@
+#ifndef RAY_RPC_SERVER_CALL_H
+#define RAY_RPC_SERVER_CALL_H
+
+#include
+
+#include "ray/common/status.h"
+#include "ray/rpc/util.h"
+
+namespace ray {
+namespace rpc {
+
+/// Represents the callback function that a `ServiceHandler` calls, with the resulting
+/// status, when it finishes handling a request.
+using RequestDoneCallback = std::function;
+
+/// Represents the state of a `ServerCall`.
+enum class ServerCallState {
+ /// The call is created and waiting for an incoming request.
+ PENDING,
+ /// The request is received and being processed.
+ PROCESSING,
+ /// Request processing is done, and the reply is being sent to the client.
+ SENDING_REPLY
+};
+
+class ServerCallFactory;
+
+/// Represents an incoming request of a gRPC server.
+///
+/// The lifecycle and state transition of a `ServerCall` is as follows:
+///
+/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED]
+///
+/// (1) The `GrpcServer` creates a `ServerCall` and uses it as the tag to accept requests
+/// from the gRPC `CompletionQueue`. Now the state is `PENDING`.
+/// (2) When a request is received, an event will be gotten from the `CompletionQueue`.
+/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call
+/// `ServerCall::HandleRequest`.
+/// (3) When the `ServiceHandler` finishes handling the request,
+/// `ServerCallImpl::SendReply` will be called, and the state becomes `SENDING_REPLY`.
+/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`.
+/// `GrpcServer` will then delete this call through a `ServerCall *`.
+///
+/// NOTE(hchen): Unlike `ServerCallImpl`, this abstract interface is not a template,
+/// so its users (e.g., `GrpcServer`) do not have to be templates either.
+class ServerCall {
+ public:
+ virtual ~ServerCall() = default;
+
+ /// Get the state of this `ServerCall`.
+ virtual ServerCallState GetState() const = 0;
+
+ /// Set state of this `ServerCall`.
+ virtual void SetState(const ServerCallState &new_state) = 0;
+
+ /// Handle the request. Called by `GrpcServer` when the request is received.
+ virtual void HandleRequest() = 0;
+
+ /// Get the factory that created this `ServerCall`.
+ virtual const ServerCallFactory &GetFactory() const = 0;
+};
+
+/// The factory that creates a particular kind of `ServerCall` objects.
+class ServerCallFactory {
+ public:
+ virtual ~ServerCallFactory() = default;
+
+ /// Create a new `ServerCall` and ask gRPC to start accepting that type of request.
+ /// \return Pointer to the `ServerCall` object.
+ virtual ServerCall *CreateCall() const = 0;
+};
+
+/// Represents the generic signature of a `FooServiceHandler::HandleBar()`
+/// function, where `Foo` is the service name and `Bar` is the RPC method name.
+///
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template
+using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *,
+ RequestDoneCallback);
+
+/// Implementation of `ServerCall`. It represents a `ServerCall` for a particular
+/// RPC method.
+///
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template
+class ServerCallImpl : public ServerCall {
+ public:
+ /// Constructor. The call starts in the `PENDING` state.
+ ///
+ /// \param[in] factory The factory which created this call.
+ /// \param[in] service_handler The service handler that handles the request.
+ /// \param[in] handle_request_function Pointer to the service handler function.
+ ServerCallImpl(
+ const ServerCallFactory &factory, ServiceHandler &service_handler,
+ HandleRequestFunction handle_request_function)
+ : state_(ServerCallState::PENDING),
+ factory_(factory),
+ service_handler_(service_handler),
+ handle_request_function_(handle_request_function),
+ response_writer_(&context_) {}
+
+ ServerCallState GetState() const override { return state_; }
+
+ void SetState(const ServerCallState &new_state) override { state_ = new_state; }
+
+ void HandleRequest() override {
+ state_ = ServerCallState::PROCESSING;
+ (service_handler_.*handle_request_function_)(request_, &reply_,
+ [this](Status status) {
+ // The handler calls this when it is
+ // done with the request; it then tells
+ // gRPC to finish this request.
+ SendReply(status);
+ });
+ }
+
+ const ServerCallFactory &GetFactory() const override { return factory_; }
+
+ private:
+ /// Tell gRPC to finish this request, using this call as the completion tag.
+ void SendReply(Status status) {
+ state_ = ServerCallState::SENDING_REPLY;
+ response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this);
+ }
+
+ /// State of this call.
+ ServerCallState state_;
+
+ /// The factory which created this call.
+ const ServerCallFactory &factory_;
+
+ /// The service handler that handles the request.
+ ServiceHandler &service_handler_;
+
+ /// Pointer to the service handler function.
+ HandleRequestFunction handle_request_function_;
+
+ /// Context for the request, allowing to tweak aspects of it such as the use
+ /// of compression, authentication, as well as to send metadata back to the client.
+ grpc::ServerContext context_;
+
+ /// The response writer.
+ grpc::ServerAsyncResponseWriter response_writer_;
+
+ /// The request message.
+ Request request_;
+
+ /// The reply message.
+ Reply reply_;
+
+ template
+ friend class ServerCallFactoryImpl;
+};
+
+/// Represents the generic signature of a `FooService::AsyncService::RequestBar()`
+/// function, where `Foo` is the service name and `Bar` is the RPC method name.
+/// \tparam GrpcService Type of the gRPC-generated service class.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template
+using RequestCallFunction = void (GrpcService::AsyncService::*)(
+ grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter *,
+ grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
+
+/// Implementation of `ServerCallFactory`.
+///
+/// \tparam GrpcService Type of the gRPC-generated service class.
+/// \tparam ServiceHandler Type of the handler that handles the request.
+/// \tparam Request Type of the request message.
+/// \tparam Reply Type of the reply message.
+template
+class ServerCallFactoryImpl : public ServerCallFactory {
+ using AsyncService = typename GrpcService::AsyncService;
+
+ public:
+ /// Constructor.
+ ///
+ /// \param[in] service The gRPC-generated `AsyncService`.
+ /// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod`
+ /// function.
+ /// \param[in] service_handler The service handler that handles the request.
+ /// \param[in] handle_request_function Pointer to the service handler function.
+ /// \param[in] cq The `CompletionQueue`; held by reference, not copied.
+ ServerCallFactoryImpl(
+ AsyncService &service,
+ RequestCallFunction request_call_function,
+ ServiceHandler &service_handler,
+ HandleRequestFunction handle_request_function,
+ const std::unique_ptr &cq)
+ : service_(service),
+ request_call_function_(request_call_function),
+ service_handler_(service_handler),
+ handle_request_function_(handle_request_function),
+ cq_(cq) {}
+
+ ServerCall *CreateCall() const override {
+ // Create a new `ServerCall`. This object will eventually be deleted by
+ // `GrpcServer::PollEventsFromCompletionQueue`.
+ auto call = new ServerCallImpl(
+ *this, service_handler_, handle_request_function_);
+ // Request the gRPC runtime to start accepting this kind of request, using the call
+ // as the tag.
+ (service_.*request_call_function_)(&call->context_, &call->request_,
+ &call->response_writer_, cq_.get(), cq_.get(),
+ call);
+ return call;
+ }
+
+ private:
+ /// The gRPC-generated `AsyncService`.
+ AsyncService &service_;
+
+ /// Pointer to the `AsyncService::RequestMethod` function.
+ RequestCallFunction request_call_function_;
+
+ /// The service handler that handles the request.
+ ServiceHandler &service_handler_;
+
+ /// Pointer to the service handler function.
+ HandleRequestFunction handle_request_function_;
+
+ /// Reference to the `CompletionQueue` pointer; dereferenced each time a call is created.
+ const std::unique_ptr &cq_;
+};
+
+} // namespace rpc
+} // namespace ray
+
+#endif
diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h
new file mode 100644
index 000000000..6ecc6c3c4
--- /dev/null
+++ b/src/ray/rpc/util.h
@@ -0,0 +1,33 @@
+#ifndef RAY_RPC_UTIL_H
+#define RAY_RPC_UTIL_H
+
+#include
+
+#include "ray/common/status.h"
+
+namespace ray {
+namespace rpc {
+
+/// Converts a ray status to a gRPC status; non-OK maps to `UNKNOWN` with the message.
+inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) {
+ if (ray_status.ok()) {
+ return grpc::Status::OK;
+ } else {
+ // TODO(hchen): Use more specific error code.
+ return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message());
+ }
+}
+
+/// Converts a gRPC status to a ray status; non-OK maps to `IOError` with the message.
+inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) {
+ if (grpc_status.ok()) {
+ return Status::OK();
+ } else {
+ return Status::IOError(grpc_status.error_message());
+ }
+}
+
+} // namespace rpc
+} // namespace ray
+
+#endif