diff --git a/.bazelrc b/.bazelrc index 488b33101..3e3c3b6c4 100644 --- a/.bazelrc +++ b/.bazelrc @@ -2,3 +2,5 @@ build --compilation_mode=opt build --action_env=PATH build --action_env=PYTHON_BIN_PATH +# This workaround is needed due to https://github.com/bazelbuild/bazel/issues/4341 +build --per_file_copt="external/com_github_grpc_grpc/.*@-DGRPC_BAZEL_BUILD" diff --git a/.travis.yml b/.travis.yml index 1888fa4ce..9a4fb66d8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -126,7 +126,7 @@ matrix: - ./ci/suppress_output ./ci/travis/install-dependencies.sh # This command should be kept in sync with ray/python/README-building-wheels.md. - - ./python/build-wheel-macos.sh + - ./ci/suppress_output ./python/build-wheel-macos.sh script: - if [ $RAY_CI_MACOS_WHEELS_AFFECTED != "1" ]; then exit; fi diff --git a/BUILD.bazel b/BUILD.bazel index 47e795011..da36eec0c 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1,12 +1,37 @@ # Bazel build # C/C++ documentation: https://docs.bazel.build/versions/master/be/c-cpp.html +load("@com_github_grpc_grpc//bazel:grpc_build_system.bzl", "grpc_proto_library") load("@com_github_google_flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") load("@//bazel:ray.bzl", "flatbuffer_py_library") load("@//bazel:cython_library.bzl", "pyx_library") COPTS = ["-DRAY_USE_GLOG"] +# Node manager gRPC lib. +grpc_proto_library( + name = "node_manager_grpc_lib", + srcs = ["src/ray/protobuf/node_manager.proto"], +) + +# Node manager server and client. 
+cc_library( + name = "node_manager_rpc_lib", + srcs = glob([ + "src/ray/rpc/*.cc", + ]), + hdrs = glob([ + "src/ray/rpc/*.h", + ]), + copts = COPTS, + deps = [ + ":node_manager_grpc_lib", + ":ray_common", + "@boost//:asio", + "@com_github_grpc_grpc//:grpc++", + ], +) + cc_binary( name = "raylet", srcs = ["src/ray/raylet/main.cc"], @@ -89,6 +114,7 @@ cc_library( ":gcs", ":gcs_fbs", ":node_manager_fbs", + ":node_manager_rpc_lib", ":object_manager", ":ray_common", ":ray_util", diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index 5598d5820..3e1e1838a 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -3,6 +3,8 @@ load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@com_github_jupp0r_prometheus_cpp//:repositories.bzl", "prometheus_cpp_repositories") load("@com_github_ray_project_ray//bazel:python_configure.bzl", "python_configure") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") +load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps") + def ray_deps_build_all(): gen_java_deps() @@ -10,3 +12,5 @@ def ray_deps_build_all(): boost_deps() prometheus_cpp_repositories() python_configure(name = "local_config_python") + grpc_deps() + diff --git a/bazel/ray_deps_setup.bzl b/bazel/ray_deps_setup.bzl index b3cd21b9b..e6dc21585 100644 --- a/bazel/ray_deps_setup.bzl +++ b/bazel/ray_deps_setup.bzl @@ -101,3 +101,11 @@ def ray_deps_setup(): # `https://github.com/jupp0r/prometheus-cpp/pull/225` getting merged. 
urls = ["https://github.com/jovany-wang/prometheus-cpp/archive/master.zip"], ) + + http_archive( + name = "com_github_grpc_grpc", + urls = [ + "https://github.com/grpc/grpc/archive/7741e806a213cba63c96234f16d712a8aa101a49.tar.gz", + ], + strip_prefix = "grpc-7741e806a213cba63c96234f16d712a8aa101a49", + ) diff --git a/ci/suppress_output b/ci/suppress_output index 623559d11..0f32b1a88 100755 --- a/ci/suppress_output +++ b/ci/suppress_output @@ -23,7 +23,7 @@ time "$@" >$TMPFILE 2>&1 CODE=$? if [ $CODE != 0 ]; then - cat $TMPFILE + tail -n 2000 $TMPFILE echo "FAILED $CODE" kill $WATCHDOG_PID exit $CODE diff --git a/ci/travis/install-bazel.sh b/ci/travis/install-bazel.sh index c9614f772..5b6d95729 100755 --- a/ci/travis/install-bazel.sh +++ b/ci/travis/install-bazel.sh @@ -16,7 +16,7 @@ else exit 1 fi -URL="https://github.com/bazelbuild/bazel/releases/download/0.21.0/bazel-0.21.0-installer-${platform}-x86_64.sh" +URL="https://github.com/bazelbuild/bazel/releases/download/0.26.1/bazel-0.26.1-installer-${platform}-x86_64.sh" wget -O install.sh $URL chmod +x install.sh ./install.sh --user diff --git a/java/BUILD.bazel b/java/BUILD.bazel index f3ae6f063..80ccabccf 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -94,11 +94,14 @@ define_java_module( ":org_ray_ray_api", ":org_ray_ray_runtime", "@plasma//:org_apache_arrow_arrow_plasma", + "@maven//:com_google_guava_guava", + "@maven//:com_sun_xml_bind_jaxb_core", + "@maven//:com_sun_xml_bind_jaxb_impl", + "@maven//:commons_io_commons_io", + "@maven//:javax_xml_bind_jaxb_api", "@maven//:org_apache_commons_commons_lang3", "@maven//:org_slf4j_slf4j_api", "@maven//:org_testng_testng", - "@maven//:com_google_guava_guava", - "@maven//:commons_io_commons_io", ], ) diff --git a/java/test/pom.xml b/java/test/pom.xml index 10f7ea4b3..6a3a31d20 100644 --- a/java/test/pom.xml +++ b/java/test/pom.xml @@ -32,11 +32,26 @@ guava 27.0.1-jre + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + commons-io 
commons-io 2.5 + + javax.xml.bind + jaxb-api + 2.3.0 + org.apache.commons commons-lang3 diff --git a/python/ray/tests/perf_integration_tests/test_perf_integration.py b/python/ray/tests/perf_integration_tests/test_perf_integration.py index 2ce2a305a..ff34fe412 100644 --- a/python/ray/tests/perf_integration_tests/test_perf_integration.py +++ b/python/ray/tests/perf_integration_tests/test_perf_integration.py @@ -6,6 +6,7 @@ import numpy as np import pytest import ray +from ray.tests.conftest import _ray_start_cluster num_tasks_submitted = [10**n for n in range(0, 6)] num_tasks_ids = ["{}_tasks".format(i) for i in num_tasks_submitted] @@ -41,3 +42,25 @@ def test_task_submission(benchmark, num_tasks): warmup() benchmark(benchmark_task_submission, num_tasks) ray.shutdown() + + +def benchmark_task_forward(f, num_tasks): + ray.get([f.remote() for _ in range(num_tasks)]) + + +@pytest.mark.benchmark +@pytest.mark.parametrize( + "num_tasks", [10**3, 10**4], + ids=[str(num) + "_tasks" for num in [10**3, 10**4]]) +def test_task_forward(benchmark, num_tasks): + with _ray_start_cluster(num_cpus=16, object_store_memory=10**6) as cluster: + cluster.add_node(resources={"my_resource": 100}) + ray.init(redis_address=cluster.redis_address) + + @ray.remote(resources={"my_resource": 0.001}) + def f(): + return 1 + + # Warm up + ray.get([f.remote() for _ in range(100)]) + benchmark(benchmark_task_forward, f, num_tasks) diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto new file mode 100644 index 000000000..8a82da1c7 --- /dev/null +++ b/src/ray/protobuf/node_manager.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package ray.rpc; + +message ForwardTaskRequest { + // The ID of the task to be forwarded. + bytes task_id = 1; + // The tasks in the uncommitted lineage of the forwarded task. This + // should include task_id. + // TODO(hchen): Currently, `uncommitted_tasks` are represented as + // flatbuffers-serialized bytes. 
This is because the flatbuffers-defined Task data + // structure is being used in many places. We should move Task and all related data + // structures to protobuf. + repeated bytes uncommitted_tasks = 2; +} + +message ForwardTaskReply { +} + +// Service for inter-node-manager communication. +service NodeManagerService { + // Forward a task and its uncommitted lineage to the remote node manager. + rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); +} diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 671a7a798..a0bde1ff0 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -99,9 +99,9 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, lineage_cache_(gcs_client_->client_table().GetLocalClientId(), gcs_client_->raylet_task_table(), gcs_client_->raylet_task_table(), config.max_lineage_size), - remote_clients_(), - remote_server_connections_(), - actor_registry_() { + actor_registry_(), + node_manager_server_(config.node_manager_port, io_service, *this), + client_call_manager_(io_service) { RAY_CHECK(heartbeat_period_.count() > 0); // Initialize the resource map with own cluster resource configuration. ClientID local_client_id = gcs_client_->client_table().GetLocalClientId(); @@ -117,6 +117,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, [this](const ObjectID &object_id) { HandleObjectMissing(object_id); })); RAY_ARROW_CHECK_OK(store_client_.Connect(config.store_socket_name.c_str())); + // Run the node manager rpc server. + node_manager_server_.Run(); } ray::Status NodeManager::RegisterGcs() { @@ -366,66 +368,24 @@ void NodeManager::ClientAdded(const ClientTableDataT &client_data) { return; } - // TODO(atumanov): make remote client lookup O(1) - if (std::find(remote_clients_.begin(), remote_clients_.end(), client_id) == - remote_clients_.end()) { - remote_clients_.push_back(client_id); - } else { - // NodeManager connection to this client was already established. 
- RAY_LOG(DEBUG) << "Received a new client connection that already exists: " + auto entry = remote_node_manager_clients_.find(client_id); + if (entry != remote_node_manager_clients_.end()) { + RAY_LOG(DEBUG) << "Received notification of a new client that already exists: " << client_id; return; } - // Establish a new NodeManager connection to this GCS client. - auto status = ConnectRemoteNodeManager(client_id, client_data.node_manager_address, - client_data.node_manager_port); - if (!status.ok()) { - // This is not a fatal error for raylet, but it should not happen. - // We need to broadcase this message. - std::string type = "raylet_connection_error"; - std::ostringstream error_message; - error_message << "Failed to connect to ray node " << client_id - << " with status: " << status.ToString() - << ". This may be since the node was recently removed."; - // We use the nil DriverID to broadcast the message to all drivers. - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), type, error_message.str(), current_time_ms())); - return; - } + // Initialize a rpc client to the new node manager. + std::unique_ptr client( + new rpc::NodeManagerClient(client_data.node_manager_address, + client_data.node_manager_port, client_call_manager_)); + remote_node_manager_clients_.emplace(client_id, std::move(client)); ResourceSet resources_total(client_data.resources_total_label, client_data.resources_total_capacity); cluster_resource_map_.emplace(client_id, SchedulingResources(resources_total)); } -ray::Status NodeManager::ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port) { - // Establish a new NodeManager connection to this GCS client. 
- RAY_LOG(INFO) << "[ConnectClient] Trying to connect to client " << client_id << " at " - << client_address << ":" << client_port; - - boost::asio::ip::tcp::socket socket(io_service_); - RAY_RETURN_NOT_OK(TcpConnect(socket, client_address, client_port)); - - // The client is connected, now send a connect message to remote node manager. - auto server_conn = TcpServerConnection::Create(std::move(socket)); - - // Prepare client connection info buffer - flatbuffers::FlatBufferBuilder fbb; - auto message = protocol::CreateConnectClient(fbb, to_flatbuf(fbb, client_id_)); - fbb.Finish(message); - // Send synchronously. - // TODO(swang): Make this a WriteMessageAsync. - RAY_RETURN_NOT_OK(server_conn->WriteMessage( - static_cast(protocol::MessageType::ConnectClient), fbb.GetSize(), - fbb.GetBufferPointer())); - - remote_server_connections_.emplace(client_id, std::move(server_conn)); - return ray::Status::OK(); -} - void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // TODO(swang): If we receive a notification for our own death, clean up and // exit immediately. @@ -440,17 +400,13 @@ void NodeManager::ClientRemoved(const ClientTableDataT &client_data) { // check that it is actually removed, or log a warning otherwise, but that may // not be necessary. - // Remove the client from the list of remote clients. - std::remove(remote_clients_.begin(), remote_clients_.end(), client_id); - // Remove the client from the resource map. cluster_resource_map_.erase(client_id); - // Remove the remote server connection. - const auto connection_entry = remote_server_connections_.find(client_id); - if (connection_entry != remote_server_connections_.end()) { - connection_entry->second->Close(); - remote_server_connections_.erase(connection_entry); + // Remove the node manager client. 
+ const auto client_entry = remote_node_manager_clients_.find(client_id); + if (client_entry != remote_node_manager_clients_.end()) { + remote_node_manager_clients_.erase(client_entry); } else { RAY_LOG(WARNING) << "Received ClientRemoved callback for an unknown client " << client_id << "."; @@ -1241,41 +1197,24 @@ void NodeManager::ProcessNewNodeManager(TcpClientConnection &node_manager_client node_manager_client.ProcessMessages(); } -void NodeManager::ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, - const uint8_t *message_data) { - const auto message_type_value = static_cast(message_type); - RAY_LOG(DEBUG) << "[NodeManager] Message " - << protocol::EnumNameMessageType(message_type_value) << "(" - << message_type << ") from node manager"; - switch (message_type_value) { - case protocol::MessageType::ConnectClient: { - auto message = flatbuffers::GetRoot(message_data); - auto client_id = from_flatbuf(*message->client_id()); - node_manager_client.SetClientID(client_id); - } break; - case protocol::MessageType::ForwardTaskRequest: { - auto message = flatbuffers::GetRoot(message_data); - TaskID task_id = from_flatbuf(*message->task_id()); - - Lineage uncommitted_lineage(*message); - const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); - RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() - << " on node " << gcs_client_->client_table().GetLocalClientId() - << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); - SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); - } break; - case protocol::MessageType::DisconnectClient: { - // TODO(rkn): We need to do some cleanup here. - RAY_LOG(DEBUG) << "Received disconnect message from remote node manager. " - << "We need to do some cleanup here."; - // Do not process any more messages from this node manager. 
- return; - } break; - default: - RAY_LOG(FATAL) << "Received unexpected message type " << message_type; +void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) { + // Get the forwarded task and its uncommitted lineage from the request. + TaskID task_id = TaskID::FromBinary(request.task_id()); + Lineage uncommitted_lineage; + for (int i = 0; i < request.uncommitted_tasks_size(); i++) { + const std::string &task_message = request.uncommitted_tasks(i); + const Task task(*flatbuffers::GetRoot( + reinterpret_cast(task_message.data()))); + RAY_CHECK(uncommitted_lineage.SetEntry(std::move(task), GcsStatus::UNCOMMITTED)); } - node_manager_client.ProcessMessages(); + const Task &task = uncommitted_lineage.GetEntry(task_id)->TaskData(); + RAY_LOG(DEBUG) << "Received forwarded task " << task.GetTaskSpecification().TaskId() + << " on node " << gcs_client_->client_table().GetLocalClientId() + << " spillback=" << task.GetTaskExecutionSpec().NumForwards(); + SubmitTask(task, uncommitted_lineage, /* forwarded = */ true); + done_callback(Status::OK()); } void NodeManager::ProcessSetResourceRequest( @@ -2253,6 +2192,16 @@ void NodeManager::ForwardTaskOrResubmit(const Task &task, void NodeManager::ForwardTask( const Task &task, const ClientID &node_id, const std::function &on_error) { + // Lookup node manager client for this node_id and use it to send the request. + auto client_entry = remote_node_manager_clients_.find(node_id); + if (client_entry == remote_node_manager_clients_.end()) { + // TODO(atumanov): caller must handle failure to ensure tasks are not lost. 
+ RAY_LOG(INFO) << "No node manager client found for GCS client id " << node_id; + on_error(ray::Status::IOError("Node manager client not found"), task); + return; + } + auto &client = client_entry->second; + const auto &spec = task.GetTaskSpecification(); auto task_id = spec.TaskId(); @@ -2272,68 +2221,61 @@ void NodeManager::ForwardTask( // Increment forward count for the forwarded task. lineage_cache_entry_task.IncrementNumForwards(); - flatbuffers::FlatBufferBuilder fbb; - auto request = uncommitted_lineage.ToFlatbuffer(fbb, task_id); - fbb.Finish(request); - RAY_LOG(DEBUG) << "Forwarding task " << task_id << " from " << gcs_client_->client_table().GetLocalClientId() << " to " << node_id << " spillback=" << lineage_cache_entry_task.GetTaskExecutionSpec().NumForwards(); - // Lookup remote server connection for this node_id and use it to send the request. - auto it = remote_server_connections_.find(node_id); - if (it == remote_server_connections_.end()) { - // TODO(atumanov): caller must handle failure to ensure tasks are not lost. - RAY_LOG(INFO) << "No NodeManager connection found for GCS client id " << node_id; - on_error(ray::Status::IOError("NodeManager connection not found"), task); - return; + // Prepare the request message. + rpc::ForwardTaskRequest request; + request.set_task_id(task_id.Binary()); + for (auto &entry : uncommitted_lineage.GetEntries()) { + request.add_uncommitted_tasks(entry.second.TaskData().Serialize()); } - auto &server_conn = it->second; + // Move the FORWARDING task to the SWAP queue so that we remember that we // have it queued locally. Once the ForwardTaskRequest has been sent, the // task will get re-queued, depending on whether the message succeeded or // not. 
local_queues_.QueueTasks({task}, TaskState::SWAP); - server_conn->WriteMessageAsync( - static_cast(protocol::MessageType::ForwardTaskRequest), fbb.GetSize(), - fbb.GetBufferPointer(), [this, on_error, task_id, node_id](ray::Status status) { - // Remove the FORWARDING task from the SWAP queue. - TaskState state; - const auto task = local_queues_.RemoveTask(task_id, &state); - RAY_CHECK(state == TaskState::SWAP); + client->ForwardTask(request, [this, on_error, task_id, node_id]( + Status status, const rpc::ForwardTaskReply &reply) { + // Remove the FORWARDING task from the SWAP queue. + TaskState state; + const auto task = local_queues_.RemoveTask(task_id, &state); + RAY_CHECK(state == TaskState::SWAP); - if (status.ok()) { - const auto &spec = task.GetTaskSpecification(); - // Mark as forwarded so that the task and its lineage are not - // re-forwarded in the future to the receiving node. - lineage_cache_.MarkTaskAsForwarded(task_id, node_id); + if (status.ok()) { + const auto &spec = task.GetTaskSpecification(); + // Mark as forwarded so that the task and its lineage are not + // re-forwarded in the future to the receiving node. + lineage_cache_.MarkTaskAsForwarded(task_id, node_id); - // Notify the task dependency manager that we are no longer responsible - // for executing this task. - task_dependency_manager_.TaskCanceled(task_id); - // Preemptively push any local arguments to the receiving node. For now, we - // only do this with actor tasks, since actor tasks must be executed by a - // specific process and therefore have affinity to the receiving node. - if (spec.IsActorTask()) { - // Iterate through the object's arguments. NOTE(swang): We do not include - // the execution dependencies here since those cannot be transferred - // between nodes. 
- for (int i = 0; i < spec.NumArgs(); ++i) { - int count = spec.ArgIdCount(i); - for (int j = 0; j < count; j++) { - ObjectID argument_id = spec.ArgId(i, j); - // If the argument is local, then push it to the receiving node. - if (task_dependency_manager_.CheckObjectLocal(argument_id)) { - object_manager_.Push(argument_id, node_id); - } - } + // Notify the task dependency manager that we are no longer responsible + // for executing this task. + task_dependency_manager_.TaskCanceled(task_id); + // Preemptively push any local arguments to the receiving node. For now, we + // only do this with actor tasks, since actor tasks must be executed by a + // specific process and therefore have affinity to the receiving node. + if (spec.IsActorTask()) { + // Iterate through the object's arguments. NOTE(swang): We do not include + // the execution dependencies here since those cannot be transferred + // between nodes. + for (int i = 0; i < spec.NumArgs(); ++i) { + int count = spec.ArgIdCount(i); + for (int j = 0; j < count; j++) { + ObjectID argument_id = spec.ArgId(i, j); + // If the argument is local, then push it to the receiving node. 
+ if (task_dependency_manager_.CheckObjectLocal(argument_id)) { + object_manager_.Push(argument_id, node_id); } } - } else { - on_error(status, task); } - }); + } + } else { + on_error(status, task); + } + }); } void NodeManager::DumpDebugState() const { @@ -2368,10 +2310,11 @@ std::string NodeManager::DebugString() const { result << "\n- num dead actors: " << statistical_data.dead_actors; result << "\n- max num handles: " << statistical_data.max_num_handles; - result << "\nRemoteConnections:"; - for (auto &pair : remote_server_connections_) { - result << "\n" << pair.first.Hex() << ": " << pair.second->DebugString(); + result << "\nRemote node manager clients: "; + for (const auto &entry : remote_node_manager_clients_) { + result << "\n" << entry.first; } + result << "\nDebugString() time ms: " << (current_time_ms() - now_ms); return result.str(); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 3f7e4d7da..616133583 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -4,6 +4,9 @@ #include // clang-format off +#include "ray/rpc/client_call.h" +#include "ray/rpc/node_manager_server.h" +#include "ray/rpc/node_manager_client.h" #include "ray/raylet/task.h" #include "ray/object_manager/object_manager.h" #include "ray/common/client_connection.h" @@ -52,7 +55,7 @@ struct NodeManagerConfig { std::string session_dir; }; -class NodeManager { +class NodeManager : public rpc::NodeManagerServiceHandler { public: /// Create a node manager. /// @@ -86,15 +89,6 @@ class NodeManager { /// \return Void. void ProcessNewNodeManager(TcpClientConnection &node_manager_client); - /// Handle a message from a remote node manager. - /// - /// \param node_manager_client The connection to the remote node manager. - /// \param message_type The type of the message. - /// \param message The message contents. - /// \return Void. 
- void ProcessNodeManagerMessage(TcpClientConnection &node_manager_client, - int64_t message_type, const uint8_t *message); - /// Subscribe to the relevant GCS tables and set up handlers. /// /// \return Status indicating whether this was done successfully or not. @@ -108,6 +102,9 @@ class NodeManager { /// Record metrics. void RecordMetrics() const; + /// Get the port of the node manager rpc server. + int GetServerPort() const { return node_manager_server_.GetPort(); } + private: /// Methods for handling clients. @@ -450,15 +447,10 @@ class NodeManager { void HandleDisconnectedActor(const ActorID &actor_id, bool was_local, bool intentional_disconnect); - /// connect to a remote node manager. - /// - /// \param client_id The client ID for the remote node manager. - /// \param client_address The IP address for the remote node manager. - /// \param client_port The listening port for the remote node manager. - /// \return True if the connect succeeds. - ray::Status ConnectRemoteNodeManager(const ClientID &client_id, - const std::string &client_address, - int32_t client_port); + /// Handle a `ForwardTask` request. + void HandleForwardTask(const rpc::ForwardTaskRequest &request, + rpc::ForwardTaskReply *reply, + rpc::RequestDoneCallback done_callback) override; // GCS client ID for this node. ClientID client_id_; @@ -505,9 +497,6 @@ class NodeManager { TaskDependencyManager task_dependency_manager_; /// The lineage cache for the GCS object and task tables. LineageCache lineage_cache_; - std::vector remote_clients_; - std::unordered_map> - remote_server_connections_; /// A mapping from actor ID to registration information about that actor /// (including which node manager owns it). std::unordered_map actor_registry_; @@ -515,6 +504,16 @@ class NodeManager { /// This map stores actor ID to the ID of the checkpoint that will be used to /// restore the actor. std::unordered_map checkpoint_id_to_restore_; + + /// The RPC server. 
+ rpc::NodeManagerServer node_manager_server_; + + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. + rpc::ClientCallManager client_call_manager_; + + /// Map from node ids to clients of the remote node managers. + std::unordered_map> + remote_node_manager_clients_; }; } // namespace raylet diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index 80630d372..473e6c263 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -61,15 +61,10 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ main_service, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), object_manager_config.object_manager_port)), - object_manager_socket_(main_service), - node_manager_acceptor_(main_service, boost::asio::ip::tcp::endpoint( - boost::asio::ip::tcp::v4(), - node_manager_config.node_manager_port)), - node_manager_socket_(main_service) { + object_manager_socket_(main_service) { // Start listening for clients. DoAccept(); DoAcceptObjectManager(); - DoAcceptNodeManager(); RAY_CHECK_OK(RegisterGcs( node_ip_address, socket_name_, object_manager_config.store_socket_name, @@ -100,7 +95,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, client_info.raylet_socket_name = raylet_socket_name; client_info.object_store_socket_name = object_store_socket_name; client_info.object_manager_port = object_manager_acceptor_.local_endpoint().port(); - client_info.node_manager_port = node_manager_acceptor_.local_endpoint().port(); + client_info.node_manager_port = node_manager_.GetServerPort(); // Add resource information. 
for (const auto &resource_pair : node_manager_config.resource_config.GetResourceMap()) { client_info.resources_total_label.push_back(resource_pair.first); @@ -120,33 +115,6 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, return Status::OK(); } -void Raylet::DoAcceptNodeManager() { - node_manager_acceptor_.async_accept(node_manager_socket_, - boost::bind(&Raylet::HandleAcceptNodeManager, this, - boost::asio::placeholders::error)); -} - -void Raylet::HandleAcceptNodeManager(const boost::system::error_code &error) { - if (!error) { - ClientHandler client_handler = - [this](TcpClientConnection &client) { - node_manager_.ProcessNewNodeManager(client); - }; - MessageHandler message_handler = - [this](std::shared_ptr client, int64_t message_type, - const uint8_t *message) { - node_manager_.ProcessNodeManagerMessage(*client, message_type, message); - }; - // Accept a new TCP client and dispatch it to the node manager. - auto new_connection = TcpClientConnection::Create( - client_handler, message_handler, std::move(node_manager_socket_), "node manager", - node_manager_message_enum, - static_cast(protocol::MessageType::DisconnectClient)); - } - // We're ready to accept another client. - DoAcceptNodeManager(); -} - void Raylet::DoAcceptObjectManager() { object_manager_acceptor_.async_accept( object_manager_socket_, boost::bind(&Raylet::HandleAcceptObjectManager, this, diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index 84274ea6e..26fe74b2b 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -63,8 +63,6 @@ class Raylet { void DoAcceptObjectManager(); /// Handle an accepted tcp client connection. 
void HandleAcceptObjectManager(const boost::system::error_code &error); - void DoAcceptNodeManager(); - void HandleAcceptNodeManager(const boost::system::error_code &error); friend class TestObjectManagerIntegration; @@ -88,10 +86,6 @@ class Raylet { boost::asio::ip::tcp::acceptor object_manager_acceptor_; /// The socket to listen on for new object manager tcp clients. boost::asio::ip::tcp::socket object_manager_socket_; - /// An acceptor for new tcp clients. - boost::asio::ip::tcp::acceptor node_manager_acceptor_; - /// The socket to listen on for new tcp clients. - boost::asio::ip::tcp::socket node_manager_socket_; }; } // namespace raylet diff --git a/src/ray/raylet/task.cc b/src/ray/raylet/task.cc index 5d6a02186..9d8036411 100644 --- a/src/ray/raylet/task.cc +++ b/src/ray/raylet/task.cc @@ -46,14 +46,18 @@ void Task::CopyTaskExecutionSpec(const Task &task) { ComputeDependencies(); } +const std::string Task::Serialize() const { + flatbuffers::FlatBufferBuilder fbb; + fbb.Finish(ToFlatbuffer(fbb)); + return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); +} + std::string SerializeTaskAsString(const std::vector *dependencies, const TaskSpecification *task_spec) { - flatbuffers::FlatBufferBuilder fbb; std::vector execution_dependencies(*dependencies); TaskExecutionSpecification execution_spec(std::move(execution_dependencies)); Task task(execution_spec, *task_spec); - fbb.Finish(task.ToFlatbuffer(fbb)); - return std::string(fbb.GetBufferPointer(), fbb.GetBufferPointer() + fbb.GetSize()); + return task.Serialize(); } } // namespace raylet diff --git a/src/ray/raylet/task.h b/src/ray/raylet/task.h index b942e2bf2..10cdfe511 100644 --- a/src/ray/raylet/task.h +++ b/src/ray/raylet/task.h @@ -84,6 +84,9 @@ class Task { /// \param task Task structure with updated dynamic information. void CopyTaskExecutionSpec(const Task &task); + /// Serialize this task to a string holding the bytes of its flatbuffer encoding.
+ const std::string Serialize() const; + private: void ComputeDependencies(); diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h new file mode 100644 index 000000000..725652cb5 --- /dev/null +++ b/src/ray/rpc/client_call.h @@ -0,0 +1,169 @@ +#ifndef RAY_RPC_CLIENT_CALL_H +#define RAY_RPC_CLIENT_CALL_H + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/util.h" + +namespace ray { +namespace rpc { + +/// Represents an outgoing gRPC request. +/// +/// The lifecycle of a `ClientCall` is as follows. +/// +/// When a client submits a new gRPC request, a new `ClientCall` object will be created +/// by `ClientCallManager::CreateCall`. Then the object will be used as the tag of +/// `CompletionQueue`. +/// +/// When the reply is received, `ClientCallManager` will get the address of this object +/// via `CompletionQueue`'s tag. And the manager should call `OnReplyReceived` and then +/// delete this object. +/// +/// NOTE(hchen): Compared to `ClientCallImpl`, this abstract interface doesn't use +/// templates. This allows its users (e.g., `ClientCallManager`) to avoid using +/// templates as well. NOTE(review): objects of this type are deleted through a `ClientCall *` in `ClientCallManager::PollEventsFromCompletionQueue`, but no virtual destructor is declared here. +class ClientCall { + public: + /// The callback to be called by `ClientCallManager` when the reply of this request is + /// received. + virtual void OnReplyReceived() = 0; +}; + +class ClientCallManager; + +/// Represents the client callback function of a particular rpc method. +/// +/// \tparam Reply Type of the reply message. +template +using ClientCallback = std::function; + +/// Implementation of the `ClientCall`. It represents a `ClientCall` for a particular +/// RPC method. +/// +/// \tparam Reply Type of the Reply message. +template +class ClientCallImpl : public ClientCall { + public: + void OnReplyReceived() override { callback_(GrpcStatusToRayStatus(status_), reply_); } + + private: + /// Constructor. + /// + /// \param[in] callback The callback function to handle the reply.
+ ClientCallImpl(const ClientCallback &callback) : callback_(callback) {} + + /// The reply message. + Reply reply_; + + /// The callback function to handle the reply. + ClientCallback callback_; + + /// The response reader. + std::unique_ptr> response_reader_; + + /// gRPC status of this request. + grpc::Status status_; + + /// Context for the client. It could be used to convey extra information to + /// the server and/or tweak certain RPC behaviors. + grpc::ClientContext context_; + + friend class ClientCallManager; +}; + +/// Represents the generic signature of a `FooService::Stub::PrepareAsyncBar` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using PrepareAsyncFunction = std::unique_ptr> ( + GrpcService::Stub::*)(grpc::ClientContext *context, const Request &request, + grpc::CompletionQueue *cq); + +/// `ClientCallManager` is used to manage outgoing gRPC requests and the lifecycles of +/// `ClientCall` objects. +/// +/// It maintains a thread that keeps polling events from `CompletionQueue`, and posts +/// the callback function to the main event loop when a reply is received. +/// +/// Multiple clients can share one `ClientCallManager`. +class ClientCallManager { + public: + /// Constructor. + /// + /// \param[in] main_service The main event loop, to which the callback functions will be + /// posted. + ClientCallManager(boost::asio::io_service &main_service) : main_service_(main_service) { + // Start the polling thread. NOTE(review): the thread is detached and keeps using `cq_` and `main_service_`; nothing here waits for it to exit before this object is destroyed. + std::thread polling_thread(&ClientCallManager::PollEventsFromCompletionQueue, this); + polling_thread.detach(); + } + + ~ClientCallManager() { cq_.Shutdown(); } + + /// Create a new `ClientCall` and send the request. + /// + /// \param[in] stub The gRPC-generated stub.
+ /// \param[in] prepare_async_function Pointer to the gRPC-generated + /// `FooService::Stub::PrepareAsyncBar` function. + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles the reply. + /// + /// \tparam GrpcService Type of the gRPC-generated service class. + /// \tparam Request Type of the request message. + /// \tparam Reply Type of the reply message. + template + ClientCall *CreateCall( + typename GrpcService::Stub &stub, + const PrepareAsyncFunction prepare_async_function, + const Request &request, const ClientCallback &callback) { + // Create a new `ClientCall` object. This object will eventually be deleted in + // `ClientCallManager::PollEventsFromCompletionQueue` when the reply is received. + auto call = new ClientCallImpl(callback); + // Send the request. + call->response_reader_ = + (stub.*prepare_async_function)(&call->context_, request, &cq_); + call->response_reader_->StartCall(); + call->response_reader_->Finish(&call->reply_, &call->status_, (void *)call); + return call; + } + + private: + /// This function runs in a background thread. It keeps polling events from the + /// `CompletionQueue`, and dispatches the events to the callbacks via the `ClientCall` + /// objects. + void PollEventsFromCompletionQueue() { + void *got_tag; + bool ok = false; + // Keep reading events from the `CompletionQueue` until it's shut down. + while (cq_.Next(&got_tag, &ok)) { + ClientCall *call = reinterpret_cast(got_tag); + if (ok) { + // Post the callback to the main event loop. + main_service_.post([call]() { + call->OnReplyReceived(); + // The call is finished, we can delete the `ClientCall` object now. + delete call; + }); + } else { + delete call; + } + } + } + + /// The main event loop, to which the callback functions will be posted. + boost::asio::io_service &main_service_; + + /// The gRPC `CompletionQueue` object used to poll events.
+ grpc::CompletionQueue cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc new file mode 100644 index 000000000..feb788da7 --- /dev/null +++ b/src/ray/rpc/grpc_server.cc @@ -0,0 +1,70 @@ +#include "ray/rpc/grpc_server.h" + +namespace ray { +namespace rpc { + +void GrpcServer::Run() { + std::string server_address("0.0.0.0:" + std::to_string(port_)); + + grpc::ServerBuilder builder; + // TODO(hchen): Add options for authentication. + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_); + // Allow subclasses to register concrete services. + RegisterServices(builder); + // Get hold of the completion queue used for the asynchronous communication + // with the gRPC runtime. + cq_ = builder.AddCompletionQueue(); + // Build and start the server. + server_ = builder.BuildAndStart(); + RAY_LOG(DEBUG) << name_ << " server started, listening on port " << port_ << "."; + + // Allow subclasses to initialize the server call factories. + InitServerCallFactories(&server_call_factories_and_concurrencies_); + for (auto &entry : server_call_factories_and_concurrencies_) { + for (int i = 0; i < entry.second; i++) { + // Create and request calls from the factory. + entry.first->CreateCall(); + } + } + // Start a thread that polls incoming requests. + std::thread polling_thread(&GrpcServer::PollEventsFromCompletionQueue, this); + polling_thread.detach(); +} + +void GrpcServer::PollEventsFromCompletionQueue() { + void *tag; + bool ok; + // Keep reading events from the `CompletionQueue` until it's shut down. + while (cq_->Next(&tag, &ok)) { + ServerCall *server_call = static_cast(tag); + // `ok == false` indicates that the server has been shut down. + // We should delete the call object in this case. + bool delete_call = !ok; + if (ok) { + switch (server_call->GetState()) { + case ServerCallState::PENDING: + // We've received a new incoming request.
Now this call object is used to + // track this request. So we need to create another call to handle the next + // incoming request. + server_call->GetFactory().CreateCall(); + server_call->SetState(ServerCallState::PROCESSING); + main_service_.post([server_call] { server_call->HandleRequest(); }); + break; + case ServerCallState::SENDING_REPLY: + // The reply has been sent, this call can be deleted now. + // This event is triggered by `ServerCallImpl::SendReply`. + delete_call = true; + break; + default: + RAY_LOG(FATAL) << "Shouldn't reach here."; + break; + } + } + if (delete_call) { + delete server_call; + } + } +} + +} // namespace rpc +} // namespace ray diff --git a/src/ray/rpc/grpc_server.h b/src/ray/rpc/grpc_server.h new file mode 100644 index 000000000..4953f4706 --- /dev/null +++ b/src/ray/rpc/grpc_server.h @@ -0,0 +1,92 @@ +#ifndef RAY_RPC_GRPC_SERVER_H +#define RAY_RPC_GRPC_SERVER_H + +#include + +#include +#include + +#include "ray/common/status.h" +#include "ray/rpc/server_call.h" + +namespace ray { +namespace rpc { + +/// Base class that represents an abstract gRPC server. +/// +/// A `GrpcServer` listens on a specific port. It owns +/// 1) a `ServerCompletionQueue` that is used for polling events from gRPC, +/// 2) and a thread that polls events from the `ServerCompletionQueue`. +/// +/// Subclasses can register one or multiple services to a `GrpcServer`, see +/// `RegisterServices`. And they should also implement `InitServerCallFactories` to decide +/// which kinds of requests this server should accept. +class GrpcServer { + public: + /// Constructor. + /// + /// \param[in] name Name of this server, used for logging and debugging purposes. + /// \param[in] port The port to bind this server to. If it's 0, a random available port + /// will be chosen. + /// \param[in] main_service The main event loop, to which service handler functions + /// will be posted.
+ GrpcServer(const std::string &name, const uint32_t port, + boost::asio::io_service &main_service) + : name_(name), port_(port), main_service_(main_service) {} + + /// Destruct this gRPC server. NOTE(review): `server_` and `cq_` are only assigned in `Run()`; confirm `Run()` is always called before destruction, otherwise they are still null here. + ~GrpcServer() { + server_->Shutdown(); + cq_->Shutdown(); + } + + /// Initialize and run this server. + void Run(); + + /// Get the port of this gRPC server. + int GetPort() const { return port_; } + + protected: + /// Subclasses should implement this method and register one or multiple gRPC services + /// to the given `ServerBuilder`. + /// + /// \param[in] builder The `ServerBuilder` instance to register services to. + virtual void RegisterServices(grpc::ServerBuilder &builder) = 0; + + /// Subclasses should implement this method to initialize the `ServerCallFactory` + /// instances, as well as specify the maximum number of concurrent requests that the gRPC + /// server can "accept" (not "handle"). Each factory will be used to create + /// `accept_concurrency` `ServerCall` objects, each of which will be used to accept and + /// handle an incoming request. + /// + /// \param[out] server_call_factories_and_concurrencies The `ServerCallFactory` objects, + /// and the maximum number of concurrent requests that the gRPC server can accept. + virtual void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) = 0; + + /// This function runs in a background thread. It keeps polling events from the + /// `ServerCompletionQueue`, and dispatches the events to the `ServiceHandler` instances + /// via the `ServerCall` objects. + void PollEventsFromCompletionQueue(); + + /// The main event loop, to which the service handler functions will be posted. + boost::asio::io_service &main_service_; + /// Name of this server, used for logging and debugging purposes. + const std::string name_; + /// Port of this server. + int port_; + /// The `ServerCallFactory` objects, and the maximum number of concurrent requests that + /// gRPC server can accept.
+ std::vector, int>> + server_call_factories_and_concurrencies_; + /// The `ServerCompletionQueue` object used for polling events. + std::unique_ptr cq_; + /// The `Server` object. + std::unique_ptr server_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/node_manager_client.h b/src/ray/rpc/node_manager_client.h new file mode 100644 index 000000000..005c75db4 --- /dev/null +++ b/src/ray/rpc/node_manager_client.h @@ -0,0 +1,56 @@ +#ifndef RAY_RPC_NODE_MANAGER_CLIENT_H +#define RAY_RPC_NODE_MANAGER_CLIENT_H + +#include + +#include + +#include "ray/common/status.h" +#include "ray/rpc/client_call.h" +#include "ray/util/logging.h" +#include "src/ray/protobuf/node_manager.grpc.pb.h" +#include "src/ray/protobuf/node_manager.pb.h" + +namespace ray { +namespace rpc { + +/// Client used for communicating with a remote node manager server. +class NodeManagerClient { + public: + /// Constructor. + /// + /// \param[in] address Address of the node manager server. + /// \param[in] port Port of the node manager server. + /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. + NodeManagerClient(const std::string &address, const int port, + ClientCallManager &client_call_manager) + : client_call_manager_(client_call_manager) { + std::shared_ptr channel = grpc::CreateChannel( + address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); + stub_ = NodeManagerService::NewStub(channel); + }; + + /// Forward a task and its uncommitted lineage. + /// + /// \param[in] request The request message. + /// \param[in] callback The callback function that handles the reply. + void ForwardTask(const ForwardTaskRequest &request, + const ClientCallback &callback) { + client_call_manager_ + .CreateCall( + *stub_, &NodeManagerService::Stub::PrepareAsyncForwardTask, request, + callback); + } + + private: + /// The gRPC-generated stub. + std::unique_ptr stub_; + + /// The `ClientCallManager` used for managing requests.
+ ClientCallManager &client_call_manager_; +}; + +} // namespace rpc +} // namespace ray + +#endif // RAY_RPC_NODE_MANAGER_CLIENT_H diff --git a/src/ray/rpc/node_manager_server.h b/src/ray/rpc/node_manager_server.h new file mode 100644 index 000000000..afaea299e --- /dev/null +++ b/src/ray/rpc/node_manager_server.h @@ -0,0 +1,71 @@ +#ifndef RAY_RPC_NODE_MANAGER_SERVER_H +#define RAY_RPC_NODE_MANAGER_SERVER_H + +#include "ray/rpc/grpc_server.h" +#include "ray/rpc/server_call.h" + +#include "src/ray/protobuf/node_manager.grpc.pb.h" +#include "src/ray/protobuf/node_manager.pb.h" + +namespace ray { +namespace rpc { + +/// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`. +class NodeManagerServiceHandler { + public: + /// Handle a `ForwardTask` request. + /// The implementation can handle this request asynchronously. When handling is done, the + /// `done_callback` should be called. + /// + /// \param[in] request The request message. + /// \param[out] reply The reply message. + /// \param[in] done_callback The callback to be called when the request is done. + virtual void HandleForwardTask(const ForwardTaskRequest &request, + ForwardTaskReply *reply, + RequestDoneCallback done_callback) = 0; +}; + +/// The `GrpcServer` for `NodeManagerService`. +class NodeManagerServer : public GrpcServer { + public: + /// Constructor. + /// + /// \param[in] port See super class. + /// \param[in] main_service See super class. + /// \param[in] service_handler The service handler that actually handles the requests. + NodeManagerServer(const uint32_t port, boost::asio::io_service &main_service, + NodeManagerServiceHandler &service_handler) + : GrpcServer("NodeManager", port, main_service), + service_handler_(service_handler){}; + + void RegisterServices(grpc::ServerBuilder &builder) override { + /// Register `NodeManagerService`.
+ builder.RegisterService(&service_); + } + + void InitServerCallFactories( + std::vector, int>> + *server_call_factories_and_concurrencies) override { + // Initialize the factory for `ForwardTask` requests. + std::unique_ptr forward_task_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestForwardTask, + service_handler_, &NodeManagerServiceHandler::HandleForwardTask, cq_)); + + // Set `ForwardTask`'s accept concurrency to 100. + server_call_factories_and_concurrencies->emplace_back( + std::move(forward_task_call_factory), 100); + } + + private: + /// The gRPC async service object. + NodeManagerService::AsyncService service_; + /// The service handler that actually handles the requests. + NodeManagerServiceHandler &service_handler_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/server_call.h b/src/ray/rpc/server_call.h new file mode 100644 index 000000000..e06278260 --- /dev/null +++ b/src/ray/rpc/server_call.h @@ -0,0 +1,233 @@ +#ifndef RAY_RPC_SERVER_CALL_H +#define RAY_RPC_SERVER_CALL_H + +#include + +#include "ray/common/status.h" +#include "ray/rpc/util.h" + +namespace ray { +namespace rpc { + +/// Represents the callback function to be called when a `ServiceHandler` finishes +/// handling a request. +using RequestDoneCallback = std::function; + +/// Represents the state of a `ServerCall`. +enum class ServerCallState { + /// The call is created and waiting for an incoming request. + PENDING, + /// Request is received and being processed. + PROCESSING, + /// Request processing is done, and reply is being sent to client. + SENDING_REPLY +}; + +class ServerCallFactory; + +/// Represents an incoming request of a gRPC server.
+/// +/// The lifecycle and state transition of a `ServerCall` is as follows: +/// +/// --(1)--> PENDING --(2)--> PROCESSING --(3)--> SENDING_REPLY --(4)--> [FINISHED] +/// +/// (1) The `GrpcServer` creates a `ServerCall` and uses it as the tag to accept requests +/// from the gRPC `CompletionQueue`. Now the state is `PENDING`. +/// (2) When a request is received, an event will be gotten from the `CompletionQueue`. +/// `GrpcServer` then should change `ServerCall`'s state to PROCESSING and call +/// `ServerCall::HandleRequest`. +/// (3) When the `ServiceHandler` finishes handling the request, `ServerCallImpl::SendReply` +/// will be called, and the state becomes `SENDING_REPLY`. +/// (4) When the reply is sent, an event will be gotten from the `CompletionQueue`. +/// `GrpcServer` will then delete this call. +/// +/// NOTE(hchen): Compared to `ServerCallImpl`, this abstract interface doesn't use +/// templates. This allows its users (e.g., `GrpcServer`) to avoid using +/// templates as well. NOTE(review): objects of this type are deleted through a `ServerCall *` in `GrpcServer::PollEventsFromCompletionQueue`, but no virtual destructor is declared here. +class ServerCall { + public: + /// Get the state of this `ServerCall`. + virtual ServerCallState GetState() const = 0; + + /// Set the state of this `ServerCall`. + virtual void SetState(const ServerCallState &new_state) = 0; + + /// Handle the request. This is the callback function to be called by + /// `GrpcServer` when the request is received. + virtual void HandleRequest() = 0; + + /// Get the factory that created this `ServerCall`. + virtual const ServerCallFactory &GetFactory() const = 0; +}; + +/// The factory that creates a particular kind of `ServerCall` objects. +class ServerCallFactory { + public: + /// Create a new `ServerCall` and request gRPC runtime to start accepting the + /// corresponding type of requests. + /// + /// \return Pointer to the `ServerCall` object. + virtual ServerCall *CreateCall() const = 0; +}; + +/// Represents the generic signature of a `FooServiceHandler::HandleBar()` +/// function, where `Foo` is the service name and `Bar` is the rpc method name.
+/// +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using HandleRequestFunction = void (ServiceHandler::*)(const Request &, Reply *, + RequestDoneCallback); + +/// Implementation of `ServerCall`. It represents a `ServerCall` for a particular +/// RPC method. +/// +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +class ServerCallImpl : public ServerCall { + public: + /// Constructor. + /// + /// \param[in] factory The factory which created this call. + /// \param[in] service_handler The service handler that handles the request. + /// \param[in] handle_request_function Pointer to the service handler function. + ServerCallImpl( + const ServerCallFactory &factory, ServiceHandler &service_handler, + HandleRequestFunction handle_request_function) + : state_(ServerCallState::PENDING), + factory_(factory), + service_handler_(service_handler), + handle_request_function_(handle_request_function), + response_writer_(&context_) {} + + ServerCallState GetState() const override { return state_; } + + void SetState(const ServerCallState &new_state) override { state_ = new_state; } + + void HandleRequest() override { + state_ = ServerCallState::PROCESSING; + (service_handler_.*handle_request_function_)(request_, &reply_, + [this](Status status) { + // When the handler is done with the + // request, tell gRPC to finish this + // request. + SendReply(status); + }); + } + + const ServerCallFactory &GetFactory() const override { return factory_; } + + private: + /// Tell gRPC to finish this request. + void SendReply(Status status) { + state_ = ServerCallState::SENDING_REPLY; + response_writer_.Finish(reply_, RayStatusToGrpcStatus(status), this); + } + + /// State of this call.
+ ServerCallState state_; + + /// The factory which created this call. + const ServerCallFactory &factory_; + + /// The service handler that handles the request. + ServiceHandler &service_handler_; + + /// Pointer to the service handler function. + HandleRequestFunction handle_request_function_; + + /// Context for the request, allowing to tweak aspects of it such as the use + /// of compression, authentication, as well as to send metadata back to the client. + grpc::ServerContext context_; + + /// The response writer. + grpc::ServerAsyncResponseWriter response_writer_; + + /// The request message. + Request request_; + + /// The reply message. + Reply reply_; + + template + friend class ServerCallFactoryImpl; +}; + +/// Represents the generic signature of a `FooService::AsyncService::RequestBar()` +/// function, where `Foo` is the service name and `Bar` is the rpc method name. +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +using RequestCallFunction = void (GrpcService::AsyncService::*)( + grpc::ServerContext *, Request *, grpc::ServerAsyncResponseWriter *, + grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); + +/// Implementation of `ServerCallFactory`. +/// +/// \tparam GrpcService Type of the gRPC-generated service class. +/// \tparam ServiceHandler Type of the handler that handles the request. +/// \tparam Request Type of the request message. +/// \tparam Reply Type of the reply message. +template +class ServerCallFactoryImpl : public ServerCallFactory { + using AsyncService = typename GrpcService::AsyncService; + + public: + /// Constructor. + /// + /// \param[in] service The gRPC-generated `AsyncService`. + /// \param[in] request_call_function Pointer to the `AsyncService::RequestMethod` + /// function. + /// \param[in] service_handler The service handler that handles the request.
+ /// \param[in] handle_request_function Pointer to the service handler function. + /// \param[in] cq The `CompletionQueue`. + ServerCallFactoryImpl( + AsyncService &service, + RequestCallFunction request_call_function, + ServiceHandler &service_handler, + HandleRequestFunction handle_request_function, + const std::unique_ptr &cq) + : service_(service), + request_call_function_(request_call_function), + service_handler_(service_handler), + handle_request_function_(handle_request_function), + cq_(cq) {} + + ServerCall *CreateCall() const override { + // Create a new `ServerCall`. This object will eventually be deleted by + // `GrpcServer::PollEventsFromCompletionQueue`. + auto call = new ServerCallImpl( + *this, service_handler_, handle_request_function_); + /// Request the gRPC runtime to start accepting this kind of request, using the call as + /// the tag. + (service_.*request_call_function_)(&call->context_, &call->request_, + &call->response_writer_, cq_.get(), cq_.get(), + call); + return call; + } + + private: + /// The gRPC-generated `AsyncService`. + AsyncService &service_; + + /// Pointer to the `AsyncService::RequestMethod` function. + RequestCallFunction request_call_function_; + + /// The service handler that handles the request. + ServiceHandler &service_handler_; + + /// Pointer to the service handler function. + HandleRequestFunction handle_request_function_; + + /// The `CompletionQueue`. + const std::unique_ptr &cq_; +}; + +} // namespace rpc +} // namespace ray + +#endif diff --git a/src/ray/rpc/util.h b/src/ray/rpc/util.h new file mode 100644 index 000000000..6ecc6c3c4 --- /dev/null +++ b/src/ray/rpc/util.h @@ -0,0 +1,33 @@ +#ifndef RAY_RPC_UTIL_H +#define RAY_RPC_UTIL_H + +#include + +#include "ray/common/status.h" + +namespace ray { +namespace rpc { + +/// Helper function that converts a ray status to a gRPC status.
+inline grpc::Status RayStatusToGrpcStatus(const Status &ray_status) { + if (ray_status.ok()) { + return grpc::Status::OK; + } else { + // TODO(hchen): Use a more specific error code. + return grpc::Status(grpc::StatusCode::UNKNOWN, ray_status.message()); + } +} + +/// Helper function that converts a gRPC status to a ray status. Every non-OK gRPC status becomes `Status::IOError` carrying only the gRPC error message. +inline Status GrpcStatusToRayStatus(const grpc::Status &grpc_status) { + if (grpc_status.ok()) { + return Status::OK(); + } else { + return Status::IOError(grpc_status.error_message()); + } +} + +} // namespace rpc +} // namespace ray + +#endif