diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 54df7e6b3..3da8de358 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -687,7 +687,7 @@ cdef class CoreWorker: def __cinit__(self, is_driver, store_socket, raylet_socket, JobID job_id, GcsClientOptions gcs_options, log_dir, - node_ip_address): + node_ip_address, node_manager_port): assert pyarrow is not None, ("Expected pyarrow to be imported from " "outside _raylet. See __init__.py for " "details.") @@ -697,8 +697,8 @@ cdef class CoreWorker: LANGUAGE_PYTHON, store_socket.encode("ascii"), raylet_socket.encode("ascii"), job_id.native(), gcs_options.native()[0], log_dir.encode("utf-8"), - node_ip_address.encode("utf-8"), task_execution_handler, - check_signals, exit_handler)) + node_ip_address.encode("utf-8"), node_manager_port, + task_execution_handler, check_signals, exit_handler)) def disconnect(self): with nogil: diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 073356385..abcc587cd 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -55,6 +55,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_string &raylet_socket, const CJobID &job_id, const CGcsClientOptions &gcs_options, const c_string &log_dir, const c_string &node_ip_address, + int node_manager_port, CRayStatus ( CTaskType task_type, const CRayFunction &ray_function, diff --git a/python/ray/node.py b/python/ray/node.py index 19d9e1db9..828459e45 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -10,6 +10,7 @@ import json import os import logging import signal +import socket import sys import tempfile import threading @@ -117,7 +118,8 @@ class Node(object): # If user does not provide the socket name, get it from Redis. if (self._plasma_store_socket_name is None - or self._raylet_socket_name is None): + or self._raylet_socket_name is None + or self._ray_params.node_manager_port is None): # Get the address info of the processes to connect to # from Redis. address_info = ray.services.get_address_info_from_redis( @@ -127,6 +129,8 @@ class Node(object): self._plasma_store_socket_name = address_info[ "object_store_address"] self._raylet_socket_name = address_info["raylet_socket_name"] + self._ray_params.node_manager_port = address_info[ + "node_manager_port"] else: # If the user specified a socket name, use it. self._plasma_store_socket_name = self._prepare_socket_file( @@ -144,6 +148,16 @@ class Node(object): ray_params.include_java = ( ray.services.include_java_from_redis(redis_client)) + if head or not connect_only: + # We need to start a local raylet. + if (self._ray_params.node_manager_port is None + or self._ray_params.node_manager_port == 0): + # No port specified. Pick a random port for the raylet to use. + # NOTE: There is a possible but unlikely race condition where + # the port is bound by another process between now and when the + # raylet starts. + self._ray_params.node_manager_port = self._get_unused_port() + # Start processes. if head: self.start_head_processes() @@ -294,6 +308,11 @@ class Node(object): """Get the node's raylet socket name.""" return self._raylet_socket_name + @property + def node_manager_port(self): + """Get the node manager's port.""" + return self._ray_params.node_manager_port + @property def address_info(self): """Get a dictionary of addresses.""" @@ -390,6 +409,13 @@ class Node(object): log_stderr_file = open(log_stderr, "a", buffering=1) return log_stdout_file, log_stderr_file + def _get_unused_port(self): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("", 0)) + port = s.getsockname()[1] + s.close() + return port + def _prepare_socket_file(self, socket_path, default_prefix): """Prepare the socket file for raylet and plasma. @@ -508,6 +534,7 @@ class Node(object): process_info = ray.services.start_raylet( self._redis_address, self._node_ip_address, + self._ray_params.node_manager_port, self._raylet_socket_name, self._plasma_store_socket_name, self._ray_params.worker_path, @@ -515,7 +542,6 @@ class Node(object): self._session_dir, self.get_resource_spec(), self._ray_params.object_manager_port, - self._ray_params.node_manager_port, self._ray_params.redis_password, use_valgrind=use_valgrind, use_profiler=use_profiler, diff --git a/python/ray/services.py b/python/ray/services.py index 2f241d4b2..43b2811c9 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -153,6 +153,7 @@ def get_address_info_from_redis_helper(redis_address, return { "object_store_address": relevant_client["ObjectStoreSocketName"], "raylet_socket_name": relevant_client["RayletSocketName"], + "node_manager_port": relevant_client["NodeManagerPort"] } @@ -1045,6 +1046,7 @@ def start_dashboard(host, def start_raylet(redis_address, node_ip_address, + node_manager_port, raylet_name, plasma_store_name, worker_path, @@ -1052,7 +1054,6 @@ def start_raylet(redis_address, session_dir, resource_spec, object_manager_port=None, - node_manager_port=None, redis_password=None, use_valgrind=False, use_profiler=False, @@ -1068,6 +1069,8 @@ def start_raylet(redis_address, Args: redis_address (str): The address of the primary Redis server. node_ip_address (str): The IP address of this node. + node_manager_port(int): The port to use for the node manager. This must + not be 0. raylet_name (str): The name of the raylet socket to create. plasma_store_name (str): The name of the plasma store socket to connect to. @@ -1078,8 +1081,6 @@ def start_raylet(redis_address, resource_spec (ResourceSpec): Resources for this raylet. object_manager_port: The port to use for the object manager. If this is None, then the object manager will choose its own port. - node_manager_port: The port to use for the node manager. If this is - None, then the node manager will choose its own port. redis_password: The password to use when connecting to Redis. use_valgrind (bool): True if the raylet should be started inside of valgrind. If this is True, use_profiler must be False. @@ -1098,6 +1099,9 @@ def start_raylet(redis_address, Returns: ProcessInfo for the process that was started. """ + # The caller must provide a node manager port so that we can correctly + # populate the command to start a worker. + assert node_manager_port is not None and node_manager_port != 0 config = config or {} config_str = ",".join(["{},{}".format(*kv) for kv in config.items()]) @@ -1137,13 +1141,14 @@ def start_raylet(redis_address, # Create the command that the Raylet will use to start workers. start_worker_command = ("{} {} " "--node-ip-address={} " + "--node-manager-port={} " "--object-store-name={} " "--raylet-name={} " "--redis-address={} " "--temp-dir={}".format( sys.executable, worker_path, node_ip_address, - plasma_store_name, raylet_name, redis_address, - temp_dir)) + node_manager_port, plasma_store_name, + raylet_name, redis_address, temp_dir)) if redis_password: start_worker_command += " --redis-password {}".format(redis_password) @@ -1151,10 +1156,6 @@ def start_raylet(redis_address, # manager to choose its own port. if object_manager_port is None: object_manager_port = 0 - # If the node manager port is None, then use 0 to cause the node manager - # to choose its own port. - if node_manager_port is None: - node_manager_port = 0 if load_code_from_local: start_worker_command += " --load-code-from-local " diff --git a/python/ray/tests/cluster_utils.py b/python/ray/tests/cluster_utils.py index 09f527575..4bbaf275f 100644 --- a/python/ray/tests/cluster_utils.py +++ b/python/ray/tests/cluster_utils.py @@ -92,6 +92,8 @@ class Cluster(object): self.webui_url = self.head_node.webui_url else: ray_params.update_if_absent(redis_address=self.redis_address) + # Let grpc pick a port. + ray_params.update(node_manager_port=0) node = ray.node.Node( ray_params, head=False, diff --git a/python/ray/worker.py b/python/ray/worker.py index 6bca9e082..c14c0591f 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1215,6 +1215,7 @@ def connect(node, gcs_options, node.get_logs_dir_path(), node.node_ip_address, + node.node_manager_port, ) worker.raylet_client = ray._raylet.RayletClient(worker.core_worker) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 26c4030ea..83c9489b5 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -19,6 +19,11 @@ parser.add_argument( required=True, type=str, help="the ip address of the worker's node") +parser.add_argument( + "--node-manager-port", + required=True, + type=int, + help="the port of the worker's node") parser.add_argument( "--redis-address", required=True, @@ -74,6 +79,7 @@ if __name__ == "__main__": ray_params = RayParams( node_ip_address=args.node_ip_address, + node_manager_port=args.node_manager_port, redis_address=args.redis_address, redis_password=args.redis_password, plasma_store_socket_name=args.object_store_name, diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index b948ab8d5..af4ed02e9 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -60,6 +60,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id, const gcs::GcsClientOptions &gcs_options, const std::string &log_dir, const std::string &node_ip_address, + int node_manager_port, const TaskExecutionCallback &task_execution_callback, std::function check_signals, const std::function exit_handler) @@ -72,6 +73,7 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, heartbeat_timer_(io_service_), worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), gcs_client_(gcs_options), + client_call_manager_(io_service_), memory_store_(std::make_shared()), task_execution_service_work_(task_execution_service_), task_execution_callback_(task_execution_callback), @@ -117,8 +119,11 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // connect to Raylet after a number of retries, this can be changed later // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. + auto grpc_client = rpc::NodeManagerWorkerClient::make( + node_ip_address, node_manager_port, client_call_manager_); raylet_client_ = std::unique_ptr(new RayletClient( - raylet_socket, WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), + std::move(grpc_client), raylet_socket, + WorkerID::FromBinary(worker_context_.GetWorkerID().Binary()), (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), language_, worker_server_.GetPort())); // Unfortunately the raylet client has to be constructed after the receivers. @@ -489,7 +494,8 @@ Status CoreWorker::SubmitTaskToRaylet(const TaskSpecification &task_spec) { if (task_deps->size() > 0) { for (size_t i = 0; i < num_returns; i++) { - reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET), task_deps); + reference_counter_.SetDependencies(task_spec.ReturnId(i, TaskTransportType::RAYLET), + task_deps); } } diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index a2869da37..2a8d6db32 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -17,6 +17,7 @@ #include "ray/core_worker/transport/raylet_transport.h" #include "ray/gcs/redis_gcs_client.h" #include "ray/raylet/raylet_client.h" +#include "ray/rpc/node_manager/node_manager_client.h" #include "ray/rpc/worker/worker_client.h" #include "ray/rpc/worker/worker_server.h" @@ -58,6 +59,7 @@ class CoreWorker { /// \param[in] log_dir Directory to write logs to. If this is empty, logs /// won't be written to a file. /// \param[in] node_ip_address IP address of the node. + /// \param[in] node_manager_port Port of the local raylet. /// \param[in] task_execution_callback Language worker callback to execute tasks. /// \parma[in] check_signals Language worker function to check for signals and handle /// them. If the function returns anything but StatusOK, any long-running @@ -70,7 +72,7 @@ class CoreWorker { const std::string &store_socket, const std::string &raylet_socket, const JobID &job_id, const gcs::GcsClientOptions &gcs_options, const std::string &log_dir, const std::string &node_ip_address, - const TaskExecutionCallback &task_execution_callback, + int node_manager_port, const TaskExecutionCallback &task_execution_callback, std::function check_signals = nullptr, std::function exit_handler = nullptr); @@ -454,6 +456,9 @@ class CoreWorker { // Client to the GCS shared by core worker interfaces. gcs::RedisGcsClient gcs_client_; + /// The `ClientCallManager` object that is shared by all `NodeManagerClient`s. + rpc::ClientCallManager client_call_manager_; + // Client to the raylet shared by core worker interfaces. std::unique_ptr raylet_client_; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 9c8367bb9..c9c4c3bd5 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -25,12 +25,17 @@ #include "ray/thirdparty/hiredis/hiredis.h" #include "ray/util/test_util.h" -namespace ray { +namespace { std::string store_executable; std::string raylet_executable; +int node_manager_port = 0; std::string mock_worker_executable; +} // namespace + +namespace ray { + static void flushall_redis(void) { redisContext *context = redisConnect("127.0.0.1", 6379); freeReplyObject(redisCommand(context, "FLUSHALL")); @@ -92,8 +97,8 @@ class CoreWorkerTest : public ::testing::Test { // a task can be scheduled to the desired node. for (int i = 0; i < num_nodes; i++) { raylet_socket_names_[i] = - StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", "127.0.0.1", - "\"CPU,4.0,resource" + std::to_string(i) + ",10\""); + StartRaylet(raylet_store_socket_names_[i], "127.0.0.1", node_manager_port + i, + "127.0.0.1", "\"CPU,4.0,resource" + std::to_string(i) + ",10\""); } } @@ -134,12 +139,12 @@ class CoreWorkerTest : public ::testing::Test { } std::string StartRaylet(std::string store_socket_name, std::string node_ip_address, - std::string redis_address, std::string resource) { + int port, std::string redis_address, std::string resource) { std::string raylet_socket_name = "/tmp/raylet" + ObjectID::FromRandom().Hex(); std::string ray_start_cmd = raylet_executable; ray_start_cmd.append(" --raylet_socket_name=" + raylet_socket_name) .append(" --store_socket_name=" + store_socket_name) - .append(" --object_manager_port=0 --node_manager_port=0") + .append(" --object_manager_port=0 --node_manager_port=" + std::to_string(port)) .append(" --node_ip_address=" + node_ip_address) .append(" --redis_address=" + redis_address) .append(" --redis_port=6379") @@ -147,7 +152,8 @@ class CoreWorkerTest : public ::testing::Test { .append(" --maximum_startup_concurrency=10") .append(" --static_resource_list=" + resource) .append(" --python_worker_command=\"" + mock_worker_executable + " " + - store_socket_name + " " + raylet_socket_name + "\"") + store_socket_name + " " + raylet_socket_name + " " + + std::to_string(port) + "\"") .append(" --config_list=initial_reconstruction_timeout_milliseconds,2000") .append(" & echo $! > " + raylet_socket_name + ".pid"); @@ -212,7 +218,7 @@ bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker, void CoreWorkerTest::TestNormalTask(std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); // Test for tasks with by-value and by-ref args. { @@ -255,7 +261,7 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); auto actor_id = CreateActorHelper(driver, resources, is_direct_call, 1000); @@ -338,7 +344,7 @@ void CoreWorkerTest::TestActorReconstruction( std::unordered_map &resources, bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); // creating actor. auto actor_id = CreateActorHelper(driver, resources, is_direct_call, 1000); @@ -394,7 +400,7 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map &r bool is_direct_call) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); // creating actor. auto actor_id = @@ -539,7 +545,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, "", - "127.0.0.1", nullptr); + "127.0.0.1", node_manager_port, nullptr); std::vector object_ids; // Create an actor. std::unordered_map resources; @@ -753,7 +759,8 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { TEST_F(SingleNodeTest, TestObjectInterface) { CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - JobID::FromInt(1), gcs_options_, "", "127.0.0.1", nullptr); + JobID::FromInt(1), gcs_options_, "", "127.0.0.1", + node_manager_port, nullptr); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -824,11 +831,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) { TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { CoreWorker worker1(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); CoreWorker worker2(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[1], raylet_socket_names_[1], NextJobId(), gcs_options_, "", "127.0.0.1", - nullptr); + node_manager_port, nullptr); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -946,9 +953,10 @@ TEST_F(TwoNodeTest, TestDirectActorTaskCrossNodesFailure) { int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); - RAY_CHECK(argc == 4); - ray::store_executable = std::string(argv[1]); - ray::raylet_executable = std::string(argv[2]); - ray::mock_worker_executable = std::string(argv[3]); + RAY_CHECK(argc == 5); + store_executable = std::string(argv[1]); + raylet_executable = std::string(argv[2]); + node_manager_port = std::stoi(std::string(argv[3])); + mock_worker_executable = std::string(argv[4]); return RUN_ALL_TESTS(); } diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index c50d4187c..096e9f372 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -19,10 +19,10 @@ namespace ray { class MockWorker { public: MockWorker(const std::string &store_socket, const std::string &raylet_socket, - const gcs::GcsClientOptions &gcs_options) + int node_manager_port, const gcs::GcsClientOptions &gcs_options) : worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, JobID::FromInt(1), gcs_options, /*log_dir=*/"", - /*node_id_address=*/"127.0.0.1", + /*node_id_address=*/"127.0.0.1", node_manager_port, std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {} void StartExecutingTasks() { worker_.StartExecutingTasks(); } @@ -71,12 +71,13 @@ class MockWorker { } // namespace ray int main(int argc, char **argv) { - RAY_CHECK(argc == 3); + RAY_CHECK(argc == 4); auto store_socket = std::string(argv[1]); auto raylet_socket = std::string(argv[2]); + auto node_manager_port = std::stoi(std::string(argv[3])); ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, ""); - ray::MockWorker worker(store_socket, raylet_socket, gcs_options); + ray::MockWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options); worker.StartExecutingTasks(); return 0; } diff --git a/src/ray/object_manager/test/object_manager_stress_test.cc b/src/ray/object_manager/test/object_manager_stress_test.cc index dd66258c3..1e564ba41 100644 --- a/src/ray/object_manager/test/object_manager_stress_test.cc +++ b/src/ray/object_manager/test/object_manager_stress_test.cc @@ -44,7 +44,7 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - auto object_manager_port = config_.object_manager_port; + auto object_manager_port = object_manager_.GetServerPort(); GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient(); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(object_manager_port); @@ -110,7 +110,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_1.pull_timeout_ms = pull_timeout_ms; om_config_1.object_chunk_size = object_chunk_size; om_config_1.push_timeout_ms = push_timeout_ms; - om_config_1.object_manager_port = 12345; + om_config_1.object_manager_port = 0; om_config_1.rpc_service_threads_number = 3; server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); @@ -123,7 +123,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_2.pull_timeout_ms = pull_timeout_ms; om_config_2.object_chunk_size = object_chunk_size; om_config_2.push_timeout_ms = push_timeout_ms; - om_config_2.object_manager_port = 23456; + om_config_2.object_manager_port = 0; om_config_2.rpc_service_threads_number = 3; server2.reset(new MockServer(main_service, om_config_2, gcs_client_2)); diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index 70f10523f..2d0bea8fa 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -38,7 +38,7 @@ class MockServer { private: ray::Status RegisterGcs(boost::asio::io_service &io_service) { - auto object_manager_port = config_.object_manager_port; + auto object_manager_port = object_manager_.GetServerPort(); GcsNodeInfo node_info = gcs_client_->client_table().GetLocalClient(); node_info.set_node_manager_address("127.0.0.1"); node_info.set_node_manager_port(object_manager_port); @@ -102,7 +102,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_1.pull_timeout_ms = pull_timeout_ms; om_config_1.object_chunk_size = object_chunk_size; om_config_1.push_timeout_ms = push_timeout_ms; - om_config_1.object_manager_port = 12345; + om_config_1.object_manager_port = 0; om_config_1.rpc_service_threads_number = 3; server1.reset(new MockServer(main_service, om_config_1, gcs_client_1)); @@ -115,7 +115,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_2.pull_timeout_ms = pull_timeout_ms; om_config_2.object_chunk_size = object_chunk_size; om_config_2.push_timeout_ms = push_timeout_ms; - om_config_2.object_manager_port = 23456; + om_config_2.object_manager_port = 0; om_config_2.rpc_service_threads_number = 3; server2.reset(new MockServer(main_service, om_config_2, gcs_client_2)); diff --git a/src/ray/protobuf/node_manager.proto b/src/ray/protobuf/node_manager.proto index 6344f8010..d68a9afe1 100644 --- a/src/ray/protobuf/node_manager.proto +++ b/src/ray/protobuf/node_manager.proto @@ -4,6 +4,14 @@ package ray.rpc; import "src/ray/protobuf/common.proto"; +// Submit a task for execution. +message SubmitTaskRequest { + TaskSpec task_spec = 1; +} + +message SubmitTaskReply { +} + message ForwardTaskRequest { // The ID of the task to be forwarded. bytes task_id = 1; @@ -56,6 +64,8 @@ message NodeStatsReply { // Service for inter-node-manager communication. service NodeManagerService { + // Submit a task (from a local or remote worker) to the node manager. + rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply); // Forward a task and its uncommitted lineage to the remote node manager. rpc ForwardTask(ForwardTaskRequest) returns (ForwardTaskReply); // Get the current node stats. diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index f32ac148f..92907079d 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -5,12 +5,9 @@ namespace ray.protocol; enum MessageType:int { - // Task is submitted to the raylet. This is sent from a worker to a - // raylet. - SubmitTask = 1, // Notify the raylet that a task has finished. This is sent from a // worker to a raylet. - TaskDone, + TaskDone = 1, // Log a message to the event table. This is sent from a worker to a raylet. EventLogMessage, // Send an initial connection message to the raylet. This is sent @@ -94,10 +91,6 @@ table Task { task_execution_spec: TaskExecutionSpecification; } -table SubmitTaskRequest { - task_spec: string; -} - // This message describes a given resource that is reserved for a worker. table ResourceIdSetInfo { // The name of the resource. diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 485fb28be..5386bbe56 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -895,9 +895,6 @@ void NodeManager::ProcessClientMessage( // because it's already disconnected. return; } break; - case protocol::MessageType::SubmitTask: { - ProcessSubmitTaskMessage(message_data); - } break; case protocol::MessageType::SetResourceRequest: { ProcessSetResourceRequest(client, message_data); } break; @@ -1175,18 +1172,6 @@ void NodeManager::ProcessDisconnectClientMessage( // these can be leaked. } -void NodeManager::ProcessSubmitTaskMessage(const uint8_t *message_data) { - // Read the task submitted by the client. - auto fbs_message = flatbuffers::GetRoot(message_data); - rpc::Task task_message; - RAY_CHECK(task_message.mutable_task_spec()->ParseFromArray( - fbs_message->task_spec()->data(), fbs_message->task_spec()->size())); - - // Submit the task to the raylet. Since the task was submitted - // locally, there is no uncommitted lineage. - SubmitTask(Task(task_message), Lineage()); -} - void NodeManager::ProcessFetchOrReconstructMessage( const std::shared_ptr &client, const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); @@ -1390,6 +1375,18 @@ void NodeManager::ProcessReportActiveObjectIDs( unordered_set_from_flatbuf(*message->object_ids())); } +void NodeManager::HandleSubmitTask(const rpc::SubmitTaskRequest &request, + rpc::SubmitTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) { + rpc::Task task; + task.mutable_task_spec()->CopyFrom(request.task_spec()); + + // Submit the task to the raylet. Since the task was submitted + // locally, there is no uncommitted lineage. + SubmitTask(Task(task), Lineage()); + send_reply_callback(Status::OK(), nullptr, nullptr); +} + void NodeManager::HandleForwardTask(const rpc::ForwardTaskRequest &request, rpc::ForwardTaskReply *reply, rpc::SendReplyCallback send_reply_callback) { diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index c07340317..f44c55202 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -406,12 +406,6 @@ class NodeManager : public rpc::NodeManagerServiceHandler { const std::shared_ptr &client, bool intentional_disconnect = false); - /// Process client message of SubmitTask - /// - /// \param message_data A pointer to the message data. - /// \return Void. - void ProcessSubmitTaskMessage(const uint8_t *message_data); - /// Process client message of FetchOrReconstruct /// /// \param client The client that sent the message. @@ -495,6 +489,11 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return void. void FinishAssignTask(const TaskID &task_id, Worker &worker, bool success); + /// Handle a `SubmitTask` request. + void HandleSubmitTask(const rpc::SubmitTaskRequest &request, + rpc::SubmitTaskReply *reply, + rpc::SendReplyCallback send_reply_callback) override; + /// Handle a `ForwardTask` request. void HandleForwardTask(const rpc::ForwardTaskRequest &request, rpc::ForwardTaskReply *reply, diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index d5fb2b4ff..6e1b34e14 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -201,10 +201,15 @@ ray::Status RayletConnection::AtomicRequestReply( return ReadMessage(reply_type, reply_message); } -RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, +RayletClient::RayletClient(std::shared_ptr grpc_client, + const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, int port) - : worker_id_(worker_id), is_worker_(is_worker), job_id_(job_id), language_(language) { + : grpc_client_(std::move(grpc_client)), + worker_id_(worker_id), + is_worker_(is_worker), + job_id_(job_id), + language_(language) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); @@ -220,11 +225,9 @@ RayletClient::RayletClient(const std::string &raylet_socket, const WorkerID &wor } ray::Status RayletClient::SubmitTask(const ray::TaskSpecification &task_spec) { - flatbuffers::FlatBufferBuilder fbb; - auto message = ray::protocol::CreateSubmitTaskRequest( - fbb, fbb.CreateString(task_spec.Serialize())); - fbb.Finish(message); - return conn_->WriteMessage(MessageType::SubmitTask, &fbb); + ray::rpc::SubmitTaskRequest request; + request.mutable_task_spec()->CopyFrom(task_spec.GetMessage()); + return grpc_client_->SubmitTask(request, /*callback=*/nullptr); } ray::Status RayletClient::TaskDone() { diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 0869c501a..c9dac0e42 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -9,6 +9,7 @@ #include "ray/common/status.h" #include "ray/common/task/task_spec.h" +#include "ray/rpc/node_manager/node_manager_client.h" using ray::ActorCheckpointID; using ray::ActorID; @@ -66,13 +67,15 @@ class RayletClient { public: /// Connect to the raylet. /// + /// \param grpc_client gRPC client to the raylet. /// \param raylet_socket The name of the socket to use to connect to the raylet. /// \param worker_id A unique ID to represent the worker. /// \param is_worker Whether this client is a worker. If it is a worker, an /// additional message will be sent to register as one. /// \param job_id The ID of the driver. This is non-nil if the client is a driver. /// \return The connection information. - RayletClient(const std::string &raylet_socket, const WorkerID &worker_id, + RayletClient(std::shared_ptr grpc_client, + const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, int port = -1); @@ -193,6 +196,9 @@ class RayletClient { const ResourceMappingType &GetResourceIDs() const { return resource_ids_; } private: + /// gRPC client to the raylet. Right now, this is only used for a couple + /// request types. + std::shared_ptr grpc_client_; const WorkerID worker_id_; const bool is_worker_; const JobID job_id_; diff --git a/src/ray/rpc/grpc_server.cc b/src/ray/rpc/grpc_server.cc index 31466dc17..ef6fea9c5 100644 --- a/src/ray/rpc/grpc_server.cc +++ b/src/ray/rpc/grpc_server.cc @@ -2,11 +2,38 @@ #include "src/ray/rpc/grpc_server.h" #include +namespace { + +bool PortNotInUse(int port) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd == -1) { + return false; + } + struct sockaddr_in server_addr = {0}; + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = htonl(INADDR_ANY); + server_addr.sin_port = htons(port); + int err = bind(fd, (struct sockaddr *)&server_addr, sizeof(server_addr)); + close(fd); + return err == 0; +} + +} // namespace + namespace ray { namespace rpc { void GrpcServer::Run() { std::string server_address("0.0.0.0:" + std::to_string(port_)); + // Unfortunately, grpc will not return an error if the specified port is in + // use. There is a race condition here where two servers could check the same + // port, but only one would succeed in binding. + if (port_ > 0) { + RAY_CHECK(PortNotInUse(port_)) + << "Port " << port_ + << " specified by caller already in use. Try passing node_manager_port=... into " + "ray.init() to pick a specific port"; + } grpc::ServerBuilder builder; // TODO(hchen): Add options for authentication. diff --git a/src/ray/rpc/node_manager/node_manager_client.h b/src/ray/rpc/node_manager/node_manager_client.h index 8a26907ec..4d02c52bc 100644 --- a/src/ray/rpc/node_manager/node_manager_client.h +++ b/src/ray/rpc/node_manager/node_manager_client.h @@ -57,6 +57,53 @@ class NodeManagerClient { ClientCallManager &client_call_manager_; }; +/// Client used by workers for communicating with a node manager server. +class NodeManagerWorkerClient + : public std::enable_shared_from_this { + public: + /// Constructor. + /// + /// \param[in] address Address of the node manager server. + /// \param[in] port Port of the node manager server. + /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. + static std::shared_ptr make( + const std::string &address, const int port, + ClientCallManager &client_call_manager) { + auto instance = new NodeManagerWorkerClient(address, port, client_call_manager); + return std::shared_ptr(instance); + } + + /// Submit a task. + ray::Status SubmitTask(const SubmitTaskRequest &request, + const ClientCallback &callback) { + auto call = client_call_manager_ + .CreateCall( + *stub_, &NodeManagerService::Stub::PrepareAsyncSubmitTask, + request, callback); + return call->GetStatus(); + } + + private: + /// Constructor. + /// + /// \param[in] address Address of the node manager server. + /// \param[in] port Port of the node manager server. + /// \param[in] client_call_manager The `ClientCallManager` used for managing requests. + NodeManagerWorkerClient(const std::string &address, const int port, + ClientCallManager &client_call_manager) + : client_call_manager_(client_call_manager) { + std::shared_ptr channel = grpc::CreateChannel( + address + ":" + std::to_string(port), grpc::InsecureChannelCredentials()); + stub_ = NodeManagerService::NewStub(channel); + }; + + /// The gRPC-generated stub. + std::unique_ptr stub_; + + /// The `ClientCallManager` used for managing requests. + ClientCallManager &client_call_manager_; +}; + } // namespace rpc } // namespace ray diff --git a/src/ray/rpc/node_manager/node_manager_server.h b/src/ray/rpc/node_manager/node_manager_server.h index 08ebe9563..40a4d8022 100644 --- a/src/ray/rpc/node_manager/node_manager_server.h +++ b/src/ray/rpc/node_manager/node_manager_server.h @@ -13,24 +13,24 @@ namespace rpc { /// Interface of the `NodeManagerService`, see `src/ray/protobuf/node_manager.proto`. class NodeManagerServiceHandler { public: - /// Handle a `ForwardTask` request. - /// The implementation can handle this request asynchronously. When handling is done, - /// the `send_reply_callback` should be called. + /// Handlers. For all of the following handlers, the implementations can + /// handle the request asynchronously. When handling is done, the + /// `send_reply_callback` should be called. See + /// src/ray/rpc/node_manager/node_manager_client.h and + /// src/ray/protobuf/node_manager.proto for a description of the + /// functionality of each handler. /// /// \param[in] request The request message. /// \param[out] reply The reply message. /// \param[in] send_reply_callback The callback to be called when the request is done. + + virtual void HandleSubmitTask(const SubmitTaskRequest &request, SubmitTaskReply *reply, + SendReplyCallback send_reply_callback) = 0; + virtual void HandleForwardTask(const ForwardTaskRequest &request, ForwardTaskReply *reply, SendReplyCallback send_reply_callback) = 0; - /// Handle a `GetNodeStats` request. - /// The implementation can handle this request asynchronously. When handling is done, - /// the `send_reply_callback` should be called. - /// - /// \param[in] request The request message. - /// \param[out] reply The reply message. - /// \param[in] send_reply_callback The callback to be called when the request is done. virtual void HandleNodeStatsRequest(const NodeStatsRequest &request, NodeStatsReply *reply, SendReplyCallback send_reply_callback) = 0; @@ -55,6 +55,13 @@ class NodeManagerGrpcService : public GrpcService { std::vector, int>> *server_call_factories_and_concurrencies) override { // Initialize the factory for requests. + std::unique_ptr submit_task_call_factory( + new ServerCallFactoryImpl( + service_, &NodeManagerService::AsyncService::RequestSubmitTask, + service_handler_, &NodeManagerServiceHandler::HandleSubmitTask, cq, + main_service_)); + std::unique_ptr forward_task_call_factory( new ServerCallFactoryImpl( @@ -70,6 +77,8 @@ class NodeManagerGrpcService : public GrpcService { main_service_)); // Set accept concurrency. + server_call_factories_and_concurrencies->emplace_back( + std::move(submit_task_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( std::move(forward_task_call_factory), 100); server_call_factories_and_concurrencies->emplace_back( diff --git a/src/ray/test/run_core_worker_tests.sh b/src/ray/test/run_core_worker_tests.sh index 7668b92ac..146090feb 100644 --- a/src/ray/test/run_core_worker_tests.sh +++ b/src/ray/test/run_core_worker_tests.sh @@ -2,6 +2,22 @@ # This needs to be run in the root directory. +# Try to find an unused port for raylet to use. +PORTS="2000 2001 2002 2003 2004 2005 2006 2007 2008 2009" +RAYLET_PORT=0 +for port in $PORTS; do + nc -z localhost $port + if [[ $? != 0 ]]; then + RAYLET_PORT=$port + break + fi +done + +if [[ $RAYLET_PORT == 0 ]]; then + echo "WARNING: Could not find unused port for raylet to use. Exiting without running tests." + exit +fi + # Cause the script to exit if a single command fails. set -e set -x @@ -38,7 +54,7 @@ sleep 2s bazel run //:redis-server -- --loglevel warning ${LOAD_MODULE_ARGS} --port 6380 & sleep 2s # Run tests. -./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $MOCK_WORKER_EXEC +./bazel-bin/core_worker_test $STORE_EXEC $RAYLET_EXEC $RAYLET_PORT $MOCK_WORKER_EXEC sleep 1s bazel run //:redis-cli -- -p 6379 shutdown bazel run //:redis-cli -- -p 6380 shutdown