From a76434ccde61bef853cb7655b22a5d67739a4913 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Wed, 20 May 2020 15:31:13 -0500 Subject: [PATCH] Add ability to specify worker and driver ports (#8071) --- .travis.yml | 4 + ci/keep_alive | 4 +- .../ray/api/test/BaseMultiLanguageTest.java | 2 + python/ray/cluster_utils.py | 2 + python/ray/node.py | 2 + python/ray/parameter.py | 34 +++++++ python/ray/scripts/scripts.py | 25 +++++- python/ray/services.py | 14 +++ python/ray/tests/conftest.py | 5 +- python/ray/tests/test_dynres.py | 6 +- python/ray/tests/test_multi_node.py | 17 +++- src/ray/core_worker/core_worker.cc | 28 ++++-- src/ray/core_worker/core_worker.h | 2 +- src/ray/core_worker/test/core_worker_test.cc | 2 + src/ray/raylet/format/node_manager.fbs | 9 ++ src/ray/raylet/main.cc | 8 ++ src/ray/raylet/node_manager.cc | 90 ++++++++++++------- src/ray/raylet/node_manager.h | 14 +++ src/ray/raylet/raylet_client.cc | 18 ++-- src/ray/raylet/raylet_client.h | 17 ++-- src/ray/raylet/worker.cc | 34 ++++--- src/ray/raylet/worker.h | 20 ++++- src/ray/raylet/worker_pool.cc | 63 ++++++++++--- src/ray/raylet/worker_pool.h | 41 +++++++-- src/ray/raylet/worker_pool_test.cc | 90 +++++++++---------- 25 files changed, 408 insertions(+), 143 deletions(-) diff --git a/.travis.yml b/.travis.yml index 689177906..0a6c777ea 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,6 +25,7 @@ matrix: - PYTHONWARNINGS=ignore - RAY_DEFAULT_BUILD=1 - RAY_CYTHON_EXAMPLES=1 + - RAY_USE_RANDOM_PORTS=1 install: - . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED before_script: @@ -37,6 +38,7 @@ matrix: - PYTHONWARNINGS=ignore - RAY_DEFAULT_BUILD=1 - RAY_CYTHON_EXAMPLES=1 + - RAY_USE_RANDOM_PORTS=1 install: - . 
./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED before_script: @@ -62,6 +64,7 @@ matrix: - RAY_INSTALL_JAVA=1 - RAY_GCS_ACTOR_SERVICE_ENABLED=true - PYTHON=3.6 PYTHONWARNINGS=ignore + - RAY_USE_RANDOM_PORTS=1 install: - . ./ci/travis/ci.sh init RAY_CI_STREAMING_PYTHON_AFFECTED,RAY_CI_STREAMING_JAVA_AFFECTED before_script: @@ -96,6 +99,7 @@ matrix: - RAY_INSTALL_JAVA=1 - RAY_GCS_SERVICE_ENABLED=false - RAY_CYTHON_EXAMPLES=1 + - RAY_USE_RANDOM_PORTS=1 install: - . ./ci/travis/ci.sh init RAY_CI_ONLY_RLLIB_AFFECTED before_script: diff --git a/ci/keep_alive b/ci/keep_alive index 313507fdc..a8198fd97 100755 --- a/ci/keep_alive +++ b/ci/keep_alive @@ -5,8 +5,8 @@ PID=$$ # Print output to avoid travis killing us watchdog() { - for i in `seq 5 5 150`; do - sleep 300 + for i in `seq 2 2 150`; do + sleep 120 echo "(running, ${i}m total)" done echo "TIMED OUT" diff --git a/java/test/src/main/java/io/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/io/ray/api/test/BaseMultiLanguageTest.java index a830bc37d..13298e5e5 100644 --- a/java/test/src/main/java/io/ray/api/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/io/ray/api/test/BaseMultiLanguageTest.java @@ -85,6 +85,8 @@ public abstract class BaseMultiLanguageTest { "start", "--head", "--redis-port=6379", + "--min-worker-port=0", + "--max-worker-port=0", String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME), String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME), String.format("--node-manager-port=%s", nodeManagerPort), diff --git a/python/ray/cluster_utils.py b/python/ray/cluster_utils.py index e34482190..9bbced5d2 100644 --- a/python/ray/cluster_utils.py +++ b/python/ray/cluster_utils.py @@ -77,6 +77,8 @@ class Cluster: "num_cpus": 1, "num_gpus": 0, "object_store_memory": 150 * 1024 * 1024, # 150 MiB + "min_worker_port": 0, + "max_worker_port": 0, } if "_internal_config" in node_args: node_args["_internal_config"] = json.loads( 
diff --git a/python/ray/node.py b/python/ray/node.py index 12fbd5f1e..5d2483a64 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -586,6 +586,8 @@ class Node: self._temp_dir, self._session_dir, self.get_resource_spec(), + self._ray_params.min_worker_port, + self._ray_params.max_worker_port, self._ray_params.object_manager_port, self._ray_params.redis_password, use_valgrind=use_valgrind, diff --git a/python/ray/parameter.py b/python/ray/parameter.py index 031b337bd..2efeb1ac0 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -1,4 +1,5 @@ import logging +import os import numpy as np @@ -35,6 +36,10 @@ class RayParams: node_ip_address (str): The IP address of the node that we are on. raylet_ip_address (str): The IP address of the raylet that this node connects to. + min_worker_port (int): The lowest port number that workers will bind + on. If not set or set to 0, random ports will be chosen. + max_worker_port (int): The highest port number that workers will bind + on. If set, min_worker_port must also be set. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in order to generate the object IDs in a consistent @@ -98,6 +103,8 @@ class RayParams: node_manager_port=None, node_ip_address=None, raylet_ip_address=None, + min_worker_port=None, + max_worker_port=None, object_id_seed=None, driver_mode=None, redirect_worker_output=None, @@ -135,6 +142,8 @@ class RayParams: self.node_manager_port = node_manager_port self.node_ip_address = node_ip_address self.raylet_ip_address = raylet_ip_address + self.min_worker_port = min_worker_port + self.max_worker_port = max_worker_port self.driver_mode = driver_mode self.redirect_worker_output = redirect_worker_output self.redirect_output = redirect_output @@ -189,6 +198,31 @@ class RayParams: self._check_usage() def _check_usage(self): + # Used primarily for testing. 
+ if os.environ.get("RAY_USE_RANDOM_PORTS", False): + if self.min_worker_port is None and self.max_worker_port is None: + self.min_worker_port = 0 + self.max_worker_port = 0 + + if self.min_worker_port is not None: + if self.min_worker_port != 0 and (self.min_worker_port < 1024 + or self.min_worker_port > 65535): + raise ValueError("min_worker_port must be 0 or an integer " + "between 1024 and 65535.") + + if self.max_worker_port is not None: + if self.min_worker_port is None: + raise ValueError("If max_worker_port is set, min_worker_port " + "must also be set.") + elif self.max_worker_port != 0: + if self.max_worker_port < 1024 or self.max_worker_port > 65535: + raise ValueError( + "max_worker_port must be 0 or an integer between " + "1024 and 65535.") + elif self.max_worker_port <= self.min_worker_port: + raise ValueError("max_worker_port must be higher than " + "min_worker_port.") + if self.resources is not None: assert "CPU" not in self.resources, ( "'CPU' should not be included in the resource dictionary. Use " diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 270ff9570..772d76ff5 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -170,6 +170,20 @@ def dashboard(cluster_config_file, cluster_name, port): required=False, type=int, help="the port to use for starting the node manager") +@click.option( + "--min-worker-port", + required=False, + type=int, + default=10000, + help="the lowest port number that workers will bind on. If not set, " + "random ports will be chosen.") +@click.option( + "--max-worker-port", + required=False, + type=int, + default=10999, + help="the highest port number that workers will bind on. 
If set, " + "'--min-worker-port' must also be set.") @click.option( "--memory", required=False, @@ -289,10 +303,11 @@ def dashboard(cluster_config_file, cluster_name, port): help="Specify whether load code from local file or GCS serialization.") def start(node_ip_address, redis_address, address, redis_port, port, num_redis_shards, redis_max_clients, redis_password, - redis_shard_ports, object_manager_port, node_manager_port, memory, - object_store_memory, redis_max_memory, num_cpus, num_gpus, resources, - head, include_webui, webui_host, block, plasma_directory, huge_pages, - autoscaling_config, no_redirect_worker_output, no_redirect_output, + redis_shard_ports, object_manager_port, node_manager_port, + min_worker_port, max_worker_port, memory, object_store_memory, + redis_max_memory, num_cpus, num_gpus, resources, head, include_webui, + webui_host, block, plasma_directory, huge_pages, autoscaling_config, + no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, include_java, java_worker_options, load_code_from_local, internal_config): """Start Ray processes manually on the local machine.""" @@ -327,6 +342,8 @@ def start(node_ip_address, redis_address, address, redis_port, port, redirect_output = None if not no_redirect_output else True ray_params = ray.parameter.RayParams( node_ip_address=node_ip_address, + min_worker_port=min_worker_port, + max_worker_port=max_worker_port, object_manager_port=object_manager_port, node_manager_port=node_manager_port, memory=memory, diff --git a/python/ray/services.py b/python/ray/services.py index 6136bb959..1e120321a 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1219,6 +1219,8 @@ def start_raylet(redis_address, temp_dir, session_dir, resource_spec, + min_worker_port=None, + max_worker_port=None, object_manager_port=None, redis_password=None, use_valgrind=False, @@ -1247,6 +1249,10 @@ def start_raylet(redis_address, resource_spec (ResourceSpec): Resources for 
this raylet. object_manager_port: The port to use for the object manager. If this is None, then the object manager will choose its own port. + min_worker_port (int): The lowest port number that workers will bind + on. If not set, random ports will be chosen. + max_worker_port (int): The highest port number that workers will bind + on. If set, min_worker_port must also be set. redis_password: The password to use when connecting to Redis. use_valgrind (bool): True if the raylet should be started inside of valgrind. If this is True, use_profiler must be False. @@ -1324,6 +1330,12 @@ def start_raylet(redis_address, if object_manager_port is None: object_manager_port = 0 + if min_worker_port is None: + min_worker_port = 0 + + if max_worker_port is None: + max_worker_port = 0 + if load_code_from_local: start_worker_command += ["--load-code-from-local"] @@ -1332,6 +1344,8 @@ def start_raylet(redis_address, "--raylet_socket_name={}".format(raylet_name), "--store_socket_name={}".format(plasma_store_name), "--object_manager_port={}".format(object_manager_port), + "--min_worker_port={}".format(min_worker_port), + "--max_worker_port={}".format(max_worker_port), "--node_manager_port={}".format(node_manager_port), "--node_ip_address={}".format(node_ip_address), "--redis_address={}".format(gcs_ip_address), diff --git a/python/ray/tests/conftest.py b/python/ray/tests/conftest.py index e06c655f3..c5e636476 100644 --- a/python/ray/tests/conftest.py +++ b/python/ray/tests/conftest.py @@ -153,7 +153,10 @@ def ray_start_object_store_memory(request): @pytest.fixture def call_ray_start(request): - parameter = getattr(request, "param", "ray start --head --num-cpus=1") + parameter = getattr( + request, "param", + "ray start --head --num-cpus=1 --min-worker-port=0 --max-worker-port=0" + ) command_args = parameter.split(" ") out = ray.utils.decode( subprocess.check_output(command_args, stderr=subprocess.STDOUT)) diff --git a/python/ray/tests/test_dynres.py b/python/ray/tests/test_dynres.py 
index a6cf2bda1..38fcaff80 100644 --- a/python/ray/tests/test_dynres.py +++ b/python/ray/tests/test_dynres.py @@ -43,8 +43,10 @@ def test_dynamic_res_deletion(shutdown_only): available_res = ray.available_resources() cluster_res = ray.cluster_resources() - assert res_name not in available_res - assert res_name not in cluster_res + def check_resources(): + return (res_name not in ray.available_resources() + and res_name not in ray.cluster_resources()) + ray.test_utils.wait_for_condition(check_resources) def test_dynamic_res_infeasible_rescheduling(ray_start_regular): diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index 605ba167a..5979161dd 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -285,7 +285,10 @@ ray.get([a.log.remote(), f.remote()]) @pytest.mark.parametrize( - "call_ray_start", ["ray start --head --num-cpus=1 --num-gpus=1"], + "call_ray_start", [ + "ray start --head --num-cpus=1 --num-gpus=1 " + + "--min-worker-port=0 --max-worker-port=0" + ], indirect=True) def test_drivers_release_resources(call_ray_start): address = call_ray_start @@ -334,6 +337,7 @@ print("success") print(output_line) if output_line == "success": return + time.sleep(1) raise RayTestTimeoutException( "Timed out waiting for process to print success.") @@ -376,6 +380,13 @@ def test_calling_start_ray_head(call_ray_stop_only): ]) subprocess.check_output(["ray", "stop"]) + # Test starting Ray with the worker port range specified. + subprocess.check_output([ + "ray", "start", "--head", "--min-worker-port", "50000", + "--max-worker-port", "51000" + ]) + subprocess.check_output(["ray", "stop"]) + # Test starting Ray with the number of CPUs specified. 
subprocess.check_output(["ray", "start", "--head", "--num-cpus", "2"]) subprocess.check_output(["ray", "stop"]) @@ -419,7 +430,7 @@ def test_calling_start_ray_head(call_ray_stop_only): assert blocked.returncode is None kill_process_by_name("raylet") - wait_for_children_of_pid_to_exit(blocked.pid, timeout=120) + wait_for_children_of_pid_to_exit(blocked.pid, timeout=30) blocked.wait() assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit" @@ -431,7 +442,7 @@ def test_calling_start_ray_head(call_ray_stop_only): wait_for_children_of_pid(blocked.pid, num_children=7, timeout=30) blocked.terminate() - wait_for_children_of_pid_to_exit(blocked.pid, timeout=120) + wait_for_children_of_pid_to_exit(blocked.pid, timeout=30) blocked.wait() assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit" diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index c806e6744..4d3f935c5 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -244,8 +244,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ client_call_manager_(new rpc::ClientCallManager(io_service_)), death_check_timer_(io_service_), internal_timer_(io_service_), - core_worker_server_(WorkerTypeString(options_.worker_type), - 0 /* let grpc choose a port */), task_queue_length_(0), num_executed_tasks_(0), task_execution_service_work_(task_execution_service_), @@ -266,10 +264,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ [this] { return local_raylet_client_->TaskDone(); })); } - // Start RPC server after all the task receivers are properly initialized. - core_worker_server_.RegisterService(grpc_service_); - core_worker_server_.Run(); - // Initialize raylet client. // NOTE(edoakes): the core_worker_server_ must be running before registering with // the raylet, as the raylet will start sending some RPC messages immediately. 
@@ -280,21 +274,37 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ auto grpc_client = rpc::NodeManagerWorkerClient::make( options_.raylet_ip_address, options_.node_manager_port, *client_call_manager_); ClientID local_raylet_id; + int assigned_port; std::unordered_map internal_config; local_raylet_client_ = std::shared_ptr(new raylet::RayletClient( io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(), (options_.worker_type == ray::WorkerType::WORKER), - worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id, - &internal_config, options_.node_ip_address, core_worker_server_.GetPort())); + worker_context_.GetCurrentJobID(), options_.language, options_.node_ip_address, + &local_raylet_id, &assigned_port, &internal_config)); connected_ = true; + RAY_CHECK(assigned_port != -1) + << "Failed to allocate a port for the worker. Please specify a wider port range " + "using the '--min-worker-port' and '--max-worker-port' arguments to 'ray " + "start'."; + // NOTE(edoakes): any initialization depending on RayConfig must happen after this line. RayConfig::instance().initialize(internal_config); + // Start RPC server after all the task receivers are properly initialized and we have + // our assigned port from the raylet. + core_worker_server_ = std::unique_ptr( + new rpc::GrpcServer(WorkerTypeString(options_.worker_type), assigned_port)); + core_worker_server_->RegisterService(grpc_service_); + core_worker_server_->Run(); + + // Tell the raylet the port that we are listening on. + RAY_CHECK_OK(local_raylet_client_->AnnounceWorkerPort(core_worker_server_->GetPort())); + // Set our own address. 
RAY_CHECK(!local_raylet_id.IsNil()); rpc_address_.set_ip_address(options_.node_ip_address); - rpc_address_.set_port(core_worker_server_.GetPort()); + rpc_address_.set_port(core_worker_server_->GetPort()); rpc_address_.set_raylet_id(local_raylet_id.Binary()); rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary()); RAY_LOG(INFO) << "Initializing worker at address: " << rpc_address_.ip_address() << ":" diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index db9d1ecc1..2afdeb9b1 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -946,7 +946,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { boost::asio::steady_timer internal_timer_; /// RPC server used to receive tasks to execute. - rpc::GrpcServer core_worker_server_; + std::unique_ptr core_worker_server_; /// Address of our RPC server. rpc::Address rpc_address_; diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 712621be0..7e838b596 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -185,6 +185,8 @@ class CoreWorkerTest : public RedisServiceManagerForTest { .append(" --node_ip_address=" + node_ip_address) .append(" --redis_address=" + redis_address) .append(" --redis_port=6379") + .append(" --min-worker-port=0") + .append(" --max-worker-port=0") .append(" --num_initial_workers=1") .append(" --maximum_startup_concurrency=10") .append(" --static_resource_list=" + resource) diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 9f13383cf..266d9d141 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -33,6 +33,8 @@ enum MessageType:int { // Send a reply confirming the successful registration of a worker or driver. // This is sent from the raylet to a worker or driver. 
RegisterClientReply, + // Send the worker's gRPC port to the raylet. + AnnounceWorkerPort, // Notify the raylet that this client is disconnecting unexpectedly. // This is sent from a worker to a raylet. DisconnectClient, @@ -160,12 +162,19 @@ table RegisterClientRequest { table RegisterClientReply { // GCS ClientID of the local node manager. raylet_id: string; + // Port that this worker should listen on. + port: int; // Keys for internal config options. internal_config_keys: [string]; // Values for internal config options corresponding to keys above. internal_config_values: [string]; } +table AnnounceWorkerPort { + // Port that this worker is listening on. + port: int; +} + table RegisterNodeManagerRequest { // GCS ClientID of the connecting node manager. client_id: string; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index e282c9527..24ddf0c78 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -31,6 +31,10 @@ DEFINE_int32(node_manager_port, -1, "The port of node manager."); DEFINE_string(node_ip_address, "", "The ip address of this node."); DEFINE_string(redis_address, "", "The ip address of redis server."); DEFINE_int32(redis_port, -1, "The port of redis server."); +DEFINE_int32(min_worker_port, 0, + "The lowest port that workers' gRPC servers will bind on."); +DEFINE_int32(max_worker_port, 0, + "The highest port that workers' gRPC servers will bind on."); DEFINE_int32(num_initial_workers, 0, "Number of initial workers."); DEFINE_int32(maximum_startup_concurrency, 1, "Maximum startup concurrency"); DEFINE_string(static_resource_list, "", "The static resource list of this node."); @@ -62,6 +66,8 @@ int main(int argc, char *argv[]) { const std::string node_ip_address = FLAGS_node_ip_address; const std::string redis_address = FLAGS_redis_address; const int redis_port = static_cast(FLAGS_redis_port); + const int min_worker_port = static_cast(FLAGS_min_worker_port); + const int max_worker_port = static_cast(FLAGS_max_worker_port); 
const int num_initial_workers = static_cast(FLAGS_num_initial_workers); const int maximum_startup_concurrency = static_cast(FLAGS_maximum_startup_concurrency); @@ -121,6 +127,8 @@ int main(int argc, char *argv[]) { node_manager_config.node_manager_port = node_manager_port; node_manager_config.num_initial_workers = num_initial_workers; node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency; + node_manager_config.min_worker_port = min_worker_port; + node_manager_config.max_worker_port = max_worker_port; if (!python_worker_command.empty()) { node_manager_config.worker_commands.emplace( diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 3e53cc3c1..bd518e289 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -141,7 +141,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, local_available_resources_(config.resource_config), worker_pool_( io_service, config.num_initial_workers, config.maximum_startup_concurrency, - gcs_client_, config.worker_commands, config.raylet_config, + config.min_worker_port, config.max_worker_port, gcs_client_, + config.worker_commands, config.raylet_config, /*starting_worker_timeout_callback=*/ [this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }), scheduling_policy_(local_queues_), @@ -398,8 +399,8 @@ void NodeManager::Heartbeat() { } void NodeManager::DoLocalGC() { - auto all_workers = worker_pool_.GetAllWorkers(); - for (const auto &driver : worker_pool_.GetAllDrivers()) { + auto all_workers = worker_pool_.GetAllRegisteredWorkers(); + for (const auto &driver : worker_pool_.GetAllRegisteredDrivers()) { all_workers.push_back(driver); } RAY_LOG(WARNING) << "Sending local GC request to " << all_workers.size() << " workers."; @@ -995,6 +996,9 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & case protocol::MessageType::RegisterClientRequest: { ProcessRegisterClientRequestMessage(client, message_data); } 
break; + case protocol::MessageType::AnnounceWorkerPort: { + ProcessAnnounceWorkerPortMessage(client, message_data); + } break; case protocol::MessageType::TaskDone: { HandleWorkerAvailable(client); } break; @@ -1086,37 +1090,21 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr & void NodeManager::ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data) { client->Register(); - flatbuffers::FlatBufferBuilder fbb; - std::vector internal_config_keys; - std::vector internal_config_values; - for (auto kv : initial_config_.raylet_config) { - internal_config_keys.push_back(kv.first); - internal_config_values.push_back(kv.second); - } - auto reply = ray::protocol::CreateRegisterClientReply( - fbb, to_flatbuf(fbb, self_node_id_), - string_vec_to_flatbuf(fbb, internal_config_keys), - string_vec_to_flatbuf(fbb, internal_config_values)); - fbb.Finish(reply); - client->WriteMessageAsync( - static_cast(protocol::MessageType::RegisterClientReply), fbb.GetSize(), - fbb.GetBufferPointer(), [this, client](const ray::Status &status) { - if (!status.ok()) { - ProcessDisconnectClientMessage(client); - } - }); auto message = flatbuffers::GetRoot(message_data); Language language = static_cast(message->language()); WorkerID worker_id = from_flatbuf(*message->worker_id()); pid_t pid = message->worker_pid(); std::string worker_ip_address = string_from_flatbuf(*message->ip_address()); - auto worker = std::make_shared(worker_id, language, worker_ip_address, - message->port(), client, client_call_manager_); + auto worker = std::make_shared(worker_id, language, worker_ip_address, client, + client_call_manager_); + + int assigned_port; if (message->is_worker()) { // Register the new worker. - if (worker_pool_.RegisterWorker(worker, pid).ok()) { - HandleWorkerAvailable(worker->Connection()); + if (!worker_pool_.RegisterWorker(worker, pid, &assigned_port).ok()) { + // Return -1 to signal to the worker that registration failed. 
+ assigned_port = -1; } } else { // Register the new driver. @@ -1127,14 +1115,56 @@ void NodeManager::ProcessRegisterClientRequestMessage( const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id); worker->AssignTaskId(driver_task_id); worker->AssignJobId(job_id); - Status status = worker_pool_.RegisterDriver(worker); + Status status = worker_pool_.RegisterDriver(worker, &assigned_port); if (status.ok()) { local_queues_.AddDriverTaskId(driver_task_id); auto job_data_ptr = gcs::CreateJobTableData( job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid); RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr)); + } else { + // Return -1 to signal to the worker that registration failed. + assigned_port = -1; } } + + flatbuffers::FlatBufferBuilder fbb; + std::vector internal_config_keys; + std::vector internal_config_values; + for (auto kv : initial_config_.raylet_config) { + internal_config_keys.push_back(kv.first); + internal_config_values.push_back(kv.second); + } + auto reply = ray::protocol::CreateRegisterClientReply( + fbb, to_flatbuf(fbb, self_node_id_), assigned_port, + string_vec_to_flatbuf(fbb, internal_config_keys), + string_vec_to_flatbuf(fbb, internal_config_values)); + fbb.Finish(reply); + client->WriteMessageAsync( + static_cast(protocol::MessageType::RegisterClientReply), fbb.GetSize(), + fbb.GetBufferPointer(), [this, client](const ray::Status &status) { + if (!status.ok()) { + ProcessDisconnectClientMessage(client); + } + }); +} + +void NodeManager::ProcessAnnounceWorkerPortMessage( + const std::shared_ptr &client, const uint8_t *message_data) { + bool is_worker = true; + std::shared_ptr worker = worker_pool_.GetRegisteredWorker(client); + if (worker == nullptr) { + is_worker = false; + worker = worker_pool_.GetRegisteredDriver(client); + } + RAY_CHECK(worker != nullptr) << "No worker exists for CoreWorker with client: " + << client->DebugString(); + + auto message = flatbuffers::GetRoot(message_data); + int port = 
message->port(); + worker->Connect(port); + if (is_worker) { + HandleWorkerAvailable(worker->Connection()); + } } void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local, @@ -2662,7 +2692,7 @@ std::shared_ptr NodeManager::CreateActorTableDataFromCreationTas } void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { - RAY_LOG(INFO) << "Finishing assigned actor task"; + RAY_LOG(DEBUG) << "Finishing assigned actor task"; ActorID actor_id; TaskID caller_id; const TaskSpecification task_spec = task.GetTaskSpecification(); @@ -3520,9 +3550,9 @@ void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_ // and return the information from HandleNodesStatsRequest. The caller of // HandleGetNodeStats should set a timeout so that the rpc finishes even if not all // workers have replied. - auto all_workers = worker_pool_.GetAllWorkers(); + auto all_workers = worker_pool_.GetAllRegisteredWorkers(); absl::flat_hash_set driver_ids; - for (auto driver : worker_pool_.GetAllDrivers()) { + for (auto driver : worker_pool_.GetAllRegisteredDrivers()) { all_workers.push_back(driver); driver_ids.insert(driver->WorkerId()); } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 11e39bf5f..558df2bee 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -58,6 +58,12 @@ struct NodeManagerConfig { /// The port to use for listening to incoming connections. If this is 0 then /// the node manager will choose its own port. int node_manager_port; + /// The lowest port number that workers started will bind on. + /// If this is set to 0, workers will bind on random ports. + int min_worker_port; + /// The highest port number that workers started will bind on. + /// If this is not set to 0, min_worker_port must also not be set to 0. + int max_worker_port; /// The initial number of workers to create. 
int num_initial_workers; /// The maximum number of workers that can be started concurrently by a @@ -453,6 +459,14 @@ class NodeManager : public rpc::NodeManagerServiceHandler { void ProcessRegisterClientRequestMessage( const std::shared_ptr &client, const uint8_t *message_data); + /// Process client message of AnnounceWorkerPort + /// + /// \param client The client that sent the message. + /// \param message_data A pointer to the message data. + /// \return Void. + void ProcessAnnounceWorkerPortMessage(const std::shared_ptr &client, + const uint8_t *message_data); + /// Handle the case that a worker is available. /// /// \param client The connection for the worker. diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 7c1e0eae0..611791d4a 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -163,11 +163,11 @@ raylet::RayletClient::RayletClient( raylet::RayletClient::RayletClient( boost::asio::io_service &io_service, - std::shared_ptr grpc_client, + std::shared_ptr grpc_client, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, - const JobID &job_id, const Language &language, ClientID *raylet_id, - std::unordered_map *internal_config, - const std::string &ip_address, int port) + const JobID &job_id, const Language &language, const std::string &ip_address, + ClientID *raylet_id, int *port, + std::unordered_map *internal_config) : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr( @@ -176,7 +176,7 @@ raylet::RayletClient::RayletClient( flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateRegisterClientRequest( fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id), - language, fbb.CreateString(ip_address), port); + language, fbb.CreateString(ip_address)); fbb.Finish(message); // Register the process ID with the raylet. 
// NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. @@ -186,6 +186,7 @@ raylet::RayletClient::RayletClient( RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet."); auto reply_message = flatbuffers::GetRoot(reply.get()); *raylet_id = ClientID::FromBinary(reply_message->raylet_id()->str()); + *port = reply_message->port(); RAY_CHECK(internal_config); auto keys = reply_message->internal_config_keys(); @@ -196,6 +197,13 @@ raylet::RayletClient::RayletClient( } } +Status raylet::RayletClient::AnnounceWorkerPort(int port) { + flatbuffers::FlatBufferBuilder fbb; + auto message = protocol::CreateAnnounceWorkerPort(fbb, port); + fbb.Finish(message); + return conn_->WriteMessage(MessageType::AnnounceWorkerPort, &fbb); +} + Status raylet::RayletClient::SubmitTask(const TaskSpecification &task_spec) { flatbuffers::FlatBufferBuilder fbb; auto message = diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 10fe98b6b..2203cf4a2 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -153,19 +153,18 @@ class RayletClient : public PinObjectsInterface, /// additional message will be sent to register as one. /// \param job_id The ID of the driver. This is non-nil if the client is a driver. /// \param language Language of the worker. + /// \param ip_address The IP address of the worker. /// \param raylet_id This will be populated with the local raylet's ClientID. /// \param internal_config This will be populated with internal config parameters /// provided by the raylet. - /// \param ip_address The IP address of the worker. - /// \param port The port that the worker will listen on for gRPC requests, if - /// any. + /// \param port The port that the worker should listen on for gRPC requests. If + /// 0, the worker should choose a random port. 
RayletClient(boost::asio::io_service &io_service, std::shared_ptr grpc_client, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, - ClientID *raylet_id, - std::unordered_map *internal_config, - const std::string &ip_address, int port = -1); + const std::string &ip_address, ClientID *raylet_id, int *port, + std::unordered_map *internal_config); /// Connect to the raylet via grpc only. /// @@ -174,6 +173,12 @@ class RayletClient : public PinObjectsInterface, ray::Status Disconnect() { return conn_->Disconnect(); }; + /// Tell the raylet which port this worker's gRPC server is listening on. + /// + /// \param The port. + /// \return ray::Status. + Status AnnounceWorkerPort(int port); + /// Submit a task using the raylet code path. /// /// \param The task specification. diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 92fb2e3b8..a9cec0788 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -27,26 +27,19 @@ namespace raylet { /// A constructor responsible for initializing the state of a worker. 
Worker::Worker(const WorkerID &worker_id, const Language &language, - const std::string &ip_address, int port, + const std::string &ip_address, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager) : worker_id_(worker_id), language_(language), ip_address_(ip_address), - port_(port), + assigned_port_(-1), + port_(-1), connection_(connection), dead_(false), blocked_(false), client_call_manager_(client_call_manager), - is_detached_actor_(false) { - if (port_ > 0) { - rpc::Address addr; - addr.set_ip_address(ip_address_); - addr.set_port(port_); - rpc_client_ = std::unique_ptr( - new rpc::CoreWorkerClient(addr, client_call_manager_)); - } -} + is_detached_actor_(false) {} void Worker::MarkDead() { dead_ = true; } @@ -71,7 +64,24 @@ Language Worker::GetLanguage() const { return language_; } const std::string Worker::IpAddress() const { return ip_address_; } -int Worker::Port() const { return port_; } +int Worker::Port() const { + RAY_CHECK(port_ > 0); + return port_; +} + +int Worker::AssignedPort() const { return assigned_port_; } + +void Worker::SetAssignedPort(int port) { assigned_port_ = port; }; + +void Worker::Connect(int port) { + RAY_CHECK(port > 0); + port_ = port; + rpc::Address addr; + addr.set_ip_address(ip_address_); + addr.set_port(port_); + rpc_client_ = std::unique_ptr( + new rpc::CoreWorkerClient(addr, client_call_manager_)); +} void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; } diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index 02c42eb14..e777b8e11 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -39,8 +39,7 @@ class Worker { /// A constructor that initializes a worker object. /// NOTE: You MUST manually set the worker process. 
Worker(const WorkerID &worker_id, const Language &language, - const std::string &ip_address, int port, - std::shared_ptr connection, + const std::string &ip_address, std::shared_ptr connection, rpc::ClientCallManager &client_call_manager); /// A destructor responsible for freeing all worker state. ~Worker() {} @@ -56,7 +55,11 @@ class Worker { void SetProcess(Process proc); Language GetLanguage() const; const std::string IpAddress() const; + /// Connect this worker's gRPC client. + void Connect(int port); int Port() const; + int AssignedPort() const; + void SetAssignedPort(int port); void AssignTaskId(const TaskID &task_id); const TaskID &GetAssignedTaskId() const; bool AddBlockedTaskId(const TaskID &task_id); @@ -124,7 +127,12 @@ class Worker { void SetAssignedTask(Task &assigned_task) { assigned_task_ = assigned_task; }; - rpc::CoreWorkerClient *rpc_client() { return rpc_client_.get(); } + bool IsRegistered() { return rpc_client_ != nullptr; } + + rpc::CoreWorkerClient *rpc_client() { + RAY_CHECK(IsRegistered()); + return rpc_client_.get(); + } private: /// The worker's ID. @@ -135,8 +143,12 @@ class Worker { Language language_; /// IP address of this worker. std::string ip_address_; + /// Port assigned to this worker by the raylet. If this is 0, the actual + /// port the worker listens (port_) on will be a random one. This is required + /// because a worker could crash before announcing its port, in which case + /// we still need to be able to mark that port as free. + int assigned_port_; /// Port that this worker listens on. - /// If port <= 0, this indicates that the worker will not listen to a port. int port_; /// Connection state of a worker. 
std::shared_ptr connection_; diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index a742d8b7d..fed72a2ab 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -55,8 +55,8 @@ namespace raylet { /// A constructor that initializes a worker pool with num_workers workers for /// each language. WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, - int maximum_startup_concurrency, - std::shared_ptr gcs_client, + int maximum_startup_concurrency, int min_worker_port, + int max_worker_port, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, const std::unordered_map &raylet_config, std::function starting_worker_timeout_callback) @@ -102,6 +102,18 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, state.worker_command = entry.second; RAY_CHECK(!state.worker_command.empty()) << "Worker command must not be empty."; } + // Initialize free ports list with all ports in the specified range. + if (min_worker_port != 0) { + if (max_worker_port == 0) { + max_worker_port = 65535; // Maximum valid port number. + } + RAY_CHECK(min_worker_port > 0 && min_worker_port <= 65535); + RAY_CHECK(max_worker_port >= min_worker_port && max_worker_port <= 65535); + free_ports_ = std::unique_ptr>(new std::queue()); + for (int port = min_worker_port; port <= max_worker_port; port++) { + free_ports_->push(port); + } + } Start(num_workers); } @@ -288,15 +300,38 @@ Process WorkerPool::StartProcess(const std::vector &worker_command_ return child; } -Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, pid_t pid) { - const auto port = worker->Port(); - RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port; +Status WorkerPool::GetNextFreePort(int *port) { + if (free_ports_) { + if (free_ports_->empty()) { + return Status::Invalid( + "Ran out of ports to allocate to workers. 
Please specify a wider port range."); + } + *port = free_ports_->front(); + free_ports_->pop(); + } else { + *port = 0; + } + return Status::OK(); +} + +void WorkerPool::MarkPortAsFree(int port) { + if (free_ports_) { + RAY_CHECK(port != 0) << ""; + free_ports_->push(port); + } +} + +Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, pid_t pid, + int *port) { auto &state = GetStateForLanguage(worker->GetLanguage()); auto it = state.starting_worker_processes.find(Process::FromPid(pid)); if (it == state.starting_worker_processes.end()) { RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid; return Status::Invalid("Unknown worker"); } + RAY_RETURN_NOT_OK(GetNextFreePort(port)); + RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << *port; + worker->SetAssignedPort(*port); worker->SetProcess(it->first); it->second--; if (it->second == 0) { @@ -307,8 +342,10 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, pid_t p return Status::OK(); } -Status WorkerPool::RegisterDriver(const std::shared_ptr &driver) { +Status WorkerPool::RegisterDriver(const std::shared_ptr &driver, int *port) { RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); + RAY_RETURN_NOT_OK(GetNextFreePort(port)); + driver->SetAssignedPort(*port); auto &state = GetStateForLanguage(driver->GetLanguage()); state.registered_drivers.insert(std::move(driver)); return Status::OK(); @@ -421,6 +458,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr &worker) { 0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())}, {stats::WorkerPidKey, std::to_string(worker->GetProcess().GetId())}}); + MarkPortAsFree(worker->AssignedPort()); return RemoveWorker(state.idle, worker); } @@ -430,6 +468,7 @@ void WorkerPool::DisconnectDriver(const std::shared_ptr &driver) { stats::CurrentDriver().Record( 0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())}, {stats::WorkerPidKey, std::to_string(driver->GetProcess().GetId())}}); + 
MarkPortAsFree(driver->AssignedPort()); } inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &language) { @@ -453,24 +492,28 @@ std::vector> WorkerPool::GetWorkersRunningTasksForJob( return workers; } -const std::vector> WorkerPool::GetAllWorkers() const { +const std::vector> WorkerPool::GetAllRegisteredWorkers() const { std::vector> workers; for (const auto &entry : states_by_lang_) { for (const auto &worker : entry.second.registered_workers) { - workers.push_back(worker); + if (worker->IsRegistered()) { + workers.push_back(worker); + } } } return workers; } -const std::vector> WorkerPool::GetAllDrivers() const { +const std::vector> WorkerPool::GetAllRegisteredDrivers() const { std::vector> drivers; for (const auto &entry : states_by_lang_) { for (const auto &driver : entry.second.registered_drivers) { - drivers.push_back(driver); + if (driver->IsRegistered()) { + drivers.push_back(driver); + } } } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index cf24256dc..50d18cfc3 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -18,6 +18,7 @@ #include #include +#include #include #include #include @@ -54,13 +55,18 @@ class WorkerPool { /// \param maximum_startup_concurrency The maximum number of worker processes /// that can be started in parallel (typically this should be set to the number of CPU /// resources on the machine). + /// \param min_worker_port The lowest port number that workers started will bind on. + /// If this is set to 0, workers will bind on random ports. + /// \param max_worker_port The highest port number that workers started will bind on. + /// If this is not set to 0, min_worker_port must also not be set to 0. /// \param worker_commands The commands used to start the worker process, grouped by /// language. /// \param raylet_config The raylet config list of this node. 
/// \param starting_worker_timeout_callback The callback that will be triggered once /// it times out to start a worker. WorkerPool(boost::asio::io_service &io_service, int num_workers, - int maximum_startup_concurrency, std::shared_ptr gcs_client, + int maximum_startup_concurrency, int min_worker_port, int max_worker_port, + std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, const std::unordered_map &raylet_config, std::function starting_worker_timeout_callback); @@ -71,15 +77,20 @@ class WorkerPool { /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// - /// \param The Worker to be registered. + /// \param[in] worker The worker to be registered. + /// \param[in] pid The PID of the worker. + /// \param[out] port The port that this worker's gRPC server should listen on. + /// Returns 0 if the worker should bind on a random port. /// \return If the registration is successful. - Status RegisterWorker(const std::shared_ptr &worker, pid_t pid); + Status RegisterWorker(const std::shared_ptr &worker, pid_t pid, int *port); /// Register a new driver. /// - /// \param The driver to be registered. + /// \param[in] worker The driver to be registered. + /// \param[out] port The port that this driver's gRPC server should listen on. + /// Returns 0 if the driver should bind on a random port. /// \return If the registration is successful. - Status RegisterDriver(const std::shared_ptr &worker); + Status RegisterDriver(const std::shared_ptr &worker, int *port); /// Get the client connection's registered worker. /// @@ -135,15 +146,15 @@ class WorkerPool { std::vector> GetWorkersRunningTasksForJob( const JobID &job_id) const; - /// Get all the workers. + /// Get all the registered workers. /// /// \return A list containing all the workers. - const std::vector> GetAllWorkers() const; + const std::vector> GetAllRegisteredWorkers() const; - /// Get all the drivers. 
+ /// Get all the registered drivers. /// /// \return A list containing all the drivers. - const std::vector> GetAllDrivers() const; + const std::vector> GetAllRegisteredDrivers() const; /// Whether there is a pending worker for the given task. /// Note that, this is only used for actor creation task with dynamic options. @@ -248,10 +259,22 @@ class WorkerPool { /// think there are unregistered workers, and won't start new workers. void MonitorStartingWorkerProcess(const Process &proc, const Language &language); + /// Get the next unallocated port in the free ports list. If a port range isn't + /// configured, returns 0. + /// \param[out] port The next available port. + Status GetNextFreePort(int *port); + + /// Mark this port as free to be used by another worker. + /// \param[in] port The port to mark as free. + void MarkPortAsFree(int port); + /// For Process class for managing subprocesses (e.g. reaping zombies). boost::asio::io_service *io_service_; /// The maximum number of worker processes that can be started concurrently. int maximum_startup_concurrency_; + /// Keeps track of unused ports that newly-created workers can bind on. + /// If null, workers will not be passed ports and will choose them randomly. + std::unique_ptr> free_ports_; /// A client connection to the GCS. std::shared_ptr gcs_client_; /// The raylet config list of this node. 
diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 3c1d9e4fa..4cb35b1d4 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -40,8 +40,8 @@ class WorkerPoolMock : public WorkerPool { explicit WorkerPoolMock(boost::asio::io_service &io_service, const WorkerCommandMap &worker_commands) - : WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, nullptr, worker_commands, - {}, []() {}), + : WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr, + worker_commands, {}, []() {}), last_worker_process_() { states_by_lang_[ray::Language::JAVA].num_workers_per_process = NUM_WORKERS_PER_PROCESS_JAVA; @@ -96,10 +96,9 @@ class WorkerPoolMock : public WorkerPool { class WorkerPoolTest : public ::testing::Test { public: - WorkerPoolTest() - : worker_pool_(io_service_), - error_message_type_(1), - client_call_manager_(io_service_) {} + WorkerPoolTest() : error_message_type_(1), client_call_manager_(io_service_) { + worker_pool_ = std::unique_ptr(new WorkerPoolMock(io_service_)); + } std::shared_ptr CreateWorker(Process proc, const Language &language = Language::PYTHON) { @@ -115,7 +114,7 @@ class WorkerPoolTest : public ::testing::Test { ClientConnection::Create(client_handler, message_handler, std::move(socket), "worker", {}, error_message_type_); std::shared_ptr worker = std::make_shared( - WorkerID::FromRandom(), language, "127.0.0.1", -1, client, client_call_manager_); + WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_); if (!proc.IsNull()) { worker->SetProcess(proc); } @@ -123,8 +122,8 @@ class WorkerPoolTest : public ::testing::Test { } void SetWorkerCommands(const WorkerCommandMap &worker_commands) { - WorkerPoolMock worker_pool(io_service_, worker_commands); - this->worker_pool_ = std::move(worker_pool); + worker_pool_ = + std::unique_ptr(new WorkerPoolMock(io_service_, worker_commands)); } void TestStartupWorkerProcessCount(Language language, int 
num_workers_per_process, @@ -136,28 +135,28 @@ class WorkerPoolTest : public ::testing::Test { static_cast(desired_initial_worker_process_count)); Process last_started_worker_process; for (int i = 0; i < desired_initial_worker_process_count; i++) { - worker_pool_.StartWorkerProcess(language); - ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() <= + worker_pool_->StartWorkerProcess(language); + ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <= expected_worker_process_count); - Process prev = worker_pool_.LastStartedWorkerProcess(); + Process prev = worker_pool_->LastStartedWorkerProcess(); if (!std::equal_to()(last_started_worker_process, prev)) { last_started_worker_process = prev; const auto &real_command = - worker_pool_.GetWorkerCommand(last_started_worker_process); + worker_pool_->GetWorkerCommand(last_started_worker_process); ASSERT_EQ(real_command, expected_worker_command); } else { - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), expected_worker_process_count); ASSERT_TRUE(i >= expected_worker_process_count); } } // Check number of starting workers - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), expected_worker_process_count); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), expected_worker_process_count); } protected: boost::asio::io_service io_service_; - WorkerPoolMock worker_pool_; + std::unique_ptr worker_pool_; int64_t error_message_type_; rpc::ClientCallManager client_call_manager_; @@ -202,7 +201,7 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) { } TEST_F(WorkerPoolTest, HandleWorkerRegistration) { - Process proc = worker_pool_.StartWorkerProcess(Language::JAVA); + Process proc = worker_pool_->StartWorkerProcess(Language::JAVA); std::vector> workers; for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { workers.push_back(CreateWorker(Process(), Language::JAVA)); @@ -210,19 +209,20 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { for (const auto 
&worker : workers) { // Check that there's still a starting worker process // before all workers have been registered - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 1); // Check that we cannot lookup the worker before it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); - RAY_CHECK_OK(worker_pool_.RegisterWorker(worker, proc.GetId())); + ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr); + int port; + RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), &port)); // Check that we can lookup the worker after it's registered. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker); + ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker); } // Check that there's no starting worker process - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 0); for (const auto &worker : workers) { - worker_pool_.DisconnectWorker(worker); + worker_pool_->DisconnectWorker(worker); // Check that we cannot lookup the worker after it's disconnected. - ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr); + ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr); } } @@ -239,21 +239,21 @@ TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) { } TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { - worker_pool_.Start(1); + worker_pool_->Start(1); // Here we try to start only 1 worker for each worker language. But since each Java // worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here, // it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for // each worker language. 
- ASSERT_NE(worker_pool_.NumWorkersStarting(), 1 * LANGUAGES.size()); - ASSERT_EQ(worker_pool_.NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA); - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), LANGUAGES.size()); + ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size()); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size()); } TEST_F(WorkerPoolTest, HandleWorkerPushPop) { // Try to pop a worker from the empty pool and make sure we don't get one. std::shared_ptr popped_worker; const auto task_spec = ExampleTaskSpec(); - popped_worker = worker_pool_.PopWorker(task_spec); + popped_worker = worker_pool_->PopWorker(task_spec); ASSERT_EQ(popped_worker, nullptr); // Create some workers. @@ -262,17 +262,17 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) { workers.insert(CreateWorker(Process::CreateNewDummy())); // Add the workers to the pool. for (auto &worker : workers) { - worker_pool_.PushWorker(worker); + worker_pool_->PushWorker(worker); } // Pop two workers and make sure they're one of the workers we created. - popped_worker = worker_pool_.PopWorker(task_spec); + popped_worker = worker_pool_->PopWorker(task_spec); ASSERT_NE(popped_worker, nullptr); ASSERT_TRUE(workers.count(popped_worker) > 0); - popped_worker = worker_pool_.PopWorker(task_spec); + popped_worker = worker_pool_->PopWorker(task_spec); ASSERT_NE(popped_worker, nullptr); ASSERT_TRUE(workers.count(popped_worker) > 0); - popped_worker = worker_pool_.PopWorker(task_spec); + popped_worker = worker_pool_->PopWorker(task_spec); ASSERT_EQ(popped_worker, nullptr); } @@ -280,21 +280,21 @@ TEST_F(WorkerPoolTest, PopActorWorker) { // Create a worker. auto worker = CreateWorker(Process::CreateNewDummy()); // Add the worker to the pool. - worker_pool_.PushWorker(worker); + worker_pool_->PushWorker(worker); // Assign an actor ID to the worker. 
const auto task_spec = ExampleTaskSpec(); - auto actor = worker_pool_.PopWorker(task_spec); + auto actor = worker_pool_->PopWorker(task_spec); const auto job_id = JobID::FromInt(1); auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); actor->AssignActorId(actor_id); - worker_pool_.PushWorker(actor); + worker_pool_->PushWorker(actor); // Check that there are no more non-actor workers. - ASSERT_EQ(worker_pool_.PopWorker(task_spec), nullptr); + ASSERT_EQ(worker_pool_->PopWorker(task_spec), nullptr); // Check that we can pop the actor worker. const auto actor_task_spec = ExampleTaskSpec(actor_id); - actor = worker_pool_.PopWorker(actor_task_spec); + actor = worker_pool_->PopWorker(actor_task_spec); ASSERT_EQ(actor, worker); ASSERT_EQ(actor->GetActorId(), actor_id); } @@ -302,19 +302,19 @@ TEST_F(WorkerPoolTest, PopActorWorker) { TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { // Create a Python Worker, and add it to the pool auto py_worker = CreateWorker(Process::CreateNewDummy(), Language::PYTHON); - worker_pool_.PushWorker(py_worker); + worker_pool_->PushWorker(py_worker); // Check that no worker will be popped if the given task is a Java task const auto java_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::JAVA); - ASSERT_EQ(worker_pool_.PopWorker(java_task_spec), nullptr); + ASSERT_EQ(worker_pool_->PopWorker(java_task_spec), nullptr); // Check that the worker can be popped if the given task is a Python task const auto py_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON); - ASSERT_NE(worker_pool_.PopWorker(py_task_spec), nullptr); + ASSERT_NE(worker_pool_->PopWorker(py_task_spec), nullptr); // Create a Java Worker, and add it to the pool auto java_worker = CreateWorker(Process::CreateNewDummy(), Language::JAVA); - worker_pool_.PushWorker(java_worker); + worker_pool_->PushWorker(java_worker); // Check that the worker will be popped now for Java task - ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr); + 
ASSERT_NE(worker_pool_->PopWorker(java_task_spec), nullptr); } TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { @@ -328,9 +328,9 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { TaskSpecification task_spec = ExampleTaskSpec( ActorID::Nil(), Language::JAVA, ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), {"test_op_0", "test_op_1"}); - worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + worker_pool_->StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); const auto real_command = - worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + worker_pool_->GetWorkerCommand(worker_pool_->LastStartedWorkerProcess()); ASSERT_EQ(real_command, std::vector( {"test_op_0", "dummy_java_worker_command",