Add ability to specify worker and driver ports (#8071)

This commit is contained in:
Edward Oakes
2020-05-20 15:31:13 -05:00
committed by GitHub
parent d76578700d
commit a76434ccde
25 changed files with 408 additions and 143 deletions
+4
View File
@@ -25,6 +25,7 @@ matrix:
- PYTHONWARNINGS=ignore
- RAY_DEFAULT_BUILD=1
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED
before_script:
@@ -37,6 +38,7 @@ matrix:
- PYTHONWARNINGS=ignore
- RAY_DEFAULT_BUILD=1
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED
before_script:
@@ -62,6 +64,7 @@ matrix:
- RAY_INSTALL_JAVA=1
- RAY_GCS_ACTOR_SERVICE_ENABLED=true
- PYTHON=3.6 PYTHONWARNINGS=ignore
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_STREAMING_PYTHON_AFFECTED,RAY_CI_STREAMING_JAVA_AFFECTED
before_script:
@@ -96,6 +99,7 @@ matrix:
- RAY_INSTALL_JAVA=1
- RAY_GCS_SERVICE_ENABLED=false
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_ONLY_RLLIB_AFFECTED
before_script:
+2 -2
View File
@@ -5,8 +5,8 @@ PID=$$
# Print output to avoid travis killing us
watchdog() {
for i in `seq 5 5 150`; do
sleep 300
for i in `seq 2 2 150`; do
sleep 120
echo "(running, ${i}m total)"
done
echo "TIMED OUT"
@@ -85,6 +85,8 @@ public abstract class BaseMultiLanguageTest {
"start",
"--head",
"--redis-port=6379",
"--min-worker-port=0",
"--max-worker-port=0",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
String.format("--node-manager-port=%s", nodeManagerPort),
+2
View File
@@ -77,6 +77,8 @@ class Cluster:
"num_cpus": 1,
"num_gpus": 0,
"object_store_memory": 150 * 1024 * 1024, # 150 MiB
"min_worker_port": 0,
"max_worker_port": 0,
}
if "_internal_config" in node_args:
node_args["_internal_config"] = json.loads(
+2
View File
@@ -586,6 +586,8 @@ class Node:
self._temp_dir,
self._session_dir,
self.get_resource_spec(),
self._ray_params.min_worker_port,
self._ray_params.max_worker_port,
self._ray_params.object_manager_port,
self._ray_params.redis_password,
use_valgrind=use_valgrind,
+34
View File
@@ -1,4 +1,5 @@
import logging
import os
import numpy as np
@@ -35,6 +36,10 @@ class RayParams:
node_ip_address (str): The IP address of the node that we are on.
raylet_ip_address (str): The IP address of the raylet that this node
connects to.
min_worker_port (int): The lowest port number that workers will bind
on. If not set or set to 0, random ports will be chosen.
max_worker_port (int): The highest port number that workers will bind
on. If set, min_worker_port must also be set.
object_id_seed (int): Used to seed the deterministic generation of
object IDs. The same value can be used across multiple runs of the
same job in order to generate the object IDs in a consistent
@@ -98,6 +103,8 @@ class RayParams:
node_manager_port=None,
node_ip_address=None,
raylet_ip_address=None,
min_worker_port=None,
max_worker_port=None,
object_id_seed=None,
driver_mode=None,
redirect_worker_output=None,
@@ -135,6 +142,8 @@ class RayParams:
self.node_manager_port = node_manager_port
self.node_ip_address = node_ip_address
self.raylet_ip_address = raylet_ip_address
self.min_worker_port = min_worker_port
self.max_worker_port = max_worker_port
self.driver_mode = driver_mode
self.redirect_worker_output = redirect_worker_output
self.redirect_output = redirect_output
@@ -189,6 +198,31 @@ class RayParams:
self._check_usage()
def _check_usage(self):
# Used primarily for testing.
if os.environ.get("RAY_USE_RANDOM_PORTS", False):
if self.min_worker_port is None and self.min_worker_port is None:
self.min_worker_port = 0
self.max_worker_port = 0
if self.min_worker_port is not None:
if self.min_worker_port != 0 and (self.min_worker_port < 1024
or self.min_worker_port > 65535):
raise ValueError("min_worker_port must be 0 or an integer "
"between 1024 and 65535.")
if self.max_worker_port is not None:
if self.min_worker_port is None:
raise ValueError("If max_worker_port is set, min_worker_port "
"must also be set.")
elif self.max_worker_port != 0:
if self.max_worker_port < 1024 or self.max_worker_port > 65535:
raise ValueError(
"max_worker_port must be 0 or an integer between "
"1024 and 65535.")
elif self.max_worker_port <= self.min_worker_port:
raise ValueError("max_worker_port must be higher than "
"min_worker_port.")
if self.resources is not None:
assert "CPU" not in self.resources, (
"'CPU' should not be included in the resource dictionary. Use "
+21 -4
View File
@@ -170,6 +170,20 @@ def dashboard(cluster_config_file, cluster_name, port):
required=False,
type=int,
help="the port to use for starting the node manager")
@click.option(
"--min-worker-port",
required=False,
type=int,
default=10000,
help="the lowest port number that workers will bind on. If not set, "
"random ports will be chosen.")
@click.option(
"--max-worker-port",
required=False,
type=int,
default=10999,
help="the highest port number that workers will bind on. If set, "
"'--min-worker-port' must also be set.")
@click.option(
"--memory",
required=False,
@@ -289,10 +303,11 @@ def dashboard(cluster_config_file, cluster_name, port):
help="Specify whether load code from local file or GCS serialization.")
def start(node_ip_address, redis_address, address, redis_port, port,
num_redis_shards, redis_max_clients, redis_password,
redis_shard_ports, object_manager_port, node_manager_port, memory,
object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
head, include_webui, webui_host, block, plasma_directory, huge_pages,
autoscaling_config, no_redirect_worker_output, no_redirect_output,
redis_shard_ports, object_manager_port, node_manager_port,
min_worker_port, max_worker_port, memory, object_store_memory,
redis_max_memory, num_cpus, num_gpus, resources, head, include_webui,
webui_host, block, plasma_directory, huge_pages, autoscaling_config,
no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
java_worker_options, load_code_from_local, internal_config):
"""Start Ray processes manually on the local machine."""
@@ -327,6 +342,8 @@ def start(node_ip_address, redis_address, address, redis_port, port,
redirect_output = None if not no_redirect_output else True
ray_params = ray.parameter.RayParams(
node_ip_address=node_ip_address,
min_worker_port=min_worker_port,
max_worker_port=max_worker_port,
object_manager_port=object_manager_port,
node_manager_port=node_manager_port,
memory=memory,
+14
View File
@@ -1219,6 +1219,8 @@ def start_raylet(redis_address,
temp_dir,
session_dir,
resource_spec,
min_worker_port=None,
max_worker_port=None,
object_manager_port=None,
redis_password=None,
use_valgrind=False,
@@ -1247,6 +1249,10 @@ def start_raylet(redis_address,
resource_spec (ResourceSpec): Resources for this raylet.
object_manager_port: The port to use for the object manager. If this is
None, then the object manager will choose its own port.
min_worker_port (int): The lowest port number that workers will bind
on. If not set, random ports will be chosen.
max_worker_port (int): The highest port number that workers will bind
on. If set, min_worker_port must also be set.
redis_password: The password to use when connecting to Redis.
use_valgrind (bool): True if the raylet should be started inside
of valgrind. If this is True, use_profiler must be False.
@@ -1324,6 +1330,12 @@ def start_raylet(redis_address,
if object_manager_port is None:
object_manager_port = 0
if min_worker_port is None:
min_worker_port = 0
if max_worker_port is None:
max_worker_port = 0
if load_code_from_local:
start_worker_command += ["--load-code-from-local"]
@@ -1332,6 +1344,8 @@ def start_raylet(redis_address,
"--raylet_socket_name={}".format(raylet_name),
"--store_socket_name={}".format(plasma_store_name),
"--object_manager_port={}".format(object_manager_port),
"--min_worker_port={}".format(min_worker_port),
"--max_worker_port={}".format(max_worker_port),
"--node_manager_port={}".format(node_manager_port),
"--node_ip_address={}".format(node_ip_address),
"--redis_address={}".format(gcs_ip_address),
+4 -1
View File
@@ -153,7 +153,10 @@ def ray_start_object_store_memory(request):
@pytest.fixture
def call_ray_start(request):
parameter = getattr(request, "param", "ray start --head --num-cpus=1")
parameter = getattr(
request, "param",
"ray start --head --num-cpus=1 --min-worker-port=0 --max-worker-port=0"
)
command_args = parameter.split(" ")
out = ray.utils.decode(
subprocess.check_output(command_args, stderr=subprocess.STDOUT))
+4 -2
View File
@@ -43,8 +43,10 @@ def test_dynamic_res_deletion(shutdown_only):
available_res = ray.available_resources()
cluster_res = ray.cluster_resources()
assert res_name not in available_res
assert res_name not in cluster_res
def check_resources():
return res_name not in available_res and res_name not in cluster_res
ray.test_utils.wait_for_condition(check_resources)
def test_dynamic_res_infeasible_rescheduling(ray_start_regular):
+14 -3
View File
@@ -285,7 +285,10 @@ ray.get([a.log.remote(), f.remote()])
@pytest.mark.parametrize(
"call_ray_start", ["ray start --head --num-cpus=1 --num-gpus=1"],
"call_ray_start", [
"ray start --head --num-cpus=1 --num-gpus=1 " +
"--min-worker-port=0 --max-worker-port=0"
],
indirect=True)
def test_drivers_release_resources(call_ray_start):
address = call_ray_start
@@ -334,6 +337,7 @@ print("success")
print(output_line)
if output_line == "success":
return
time.sleep(1)
raise RayTestTimeoutException(
"Timed out waiting for process to print success.")
@@ -376,6 +380,13 @@ def test_calling_start_ray_head(call_ray_stop_only):
])
subprocess.check_output(["ray", "stop"])
# Test starting Ray with the worker port range specified.
subprocess.check_output([
"ray", "start", "--head", "--min-worker-port", "50000",
"--max-worker-port", "51000"
])
subprocess.check_output(["ray", "stop"])
# Test starting Ray with the number of CPUs specified.
subprocess.check_output(["ray", "start", "--head", "--num-cpus", "2"])
subprocess.check_output(["ray", "stop"])
@@ -419,7 +430,7 @@ def test_calling_start_ray_head(call_ray_stop_only):
assert blocked.returncode is None
kill_process_by_name("raylet")
wait_for_children_of_pid_to_exit(blocked.pid, timeout=120)
wait_for_children_of_pid_to_exit(blocked.pid, timeout=30)
blocked.wait()
assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit"
@@ -431,7 +442,7 @@ def test_calling_start_ray_head(call_ray_stop_only):
wait_for_children_of_pid(blocked.pid, num_children=7, timeout=30)
blocked.terminate()
wait_for_children_of_pid_to_exit(blocked.pid, timeout=120)
wait_for_children_of_pid_to_exit(blocked.pid, timeout=30)
blocked.wait()
assert blocked.returncode != 0, "ray start shouldn't return 0 on bad exit"
+19 -9
View File
@@ -244,8 +244,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
client_call_manager_(new rpc::ClientCallManager(io_service_)),
death_check_timer_(io_service_),
internal_timer_(io_service_),
core_worker_server_(WorkerTypeString(options_.worker_type),
0 /* let grpc choose a port */),
task_queue_length_(0),
num_executed_tasks_(0),
task_execution_service_work_(task_execution_service_),
@@ -266,10 +264,6 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
[this] { return local_raylet_client_->TaskDone(); }));
}
// Start RPC server after all the task receivers are properly initialized.
core_worker_server_.RegisterService(grpc_service_);
core_worker_server_.Run();
// Initialize raylet client.
// NOTE(edoakes): the core_worker_server_ must be running before registering with
// the raylet, as the raylet will start sending some RPC messages immediately.
@@ -280,21 +274,37 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
auto grpc_client = rpc::NodeManagerWorkerClient::make(
options_.raylet_ip_address, options_.node_manager_port, *client_call_manager_);
ClientID local_raylet_id;
int assigned_port;
std::unordered_map<std::string, std::string> internal_config;
local_raylet_client_ = std::shared_ptr<raylet::RayletClient>(new raylet::RayletClient(
io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
(options_.worker_type == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id,
&internal_config, options_.node_ip_address, core_worker_server_.GetPort()));
worker_context_.GetCurrentJobID(), options_.language, options_.node_ip_address,
&local_raylet_id, &assigned_port, &internal_config));
connected_ = true;
RAY_CHECK(assigned_port != -1)
<< "Failed to allocate a port for the worker. Please specify a wider port range "
"using the '--min-worker-port' and '--max-worker-port' arguments to 'ray "
"start'.";
// NOTE(edoakes): any initialization depending on RayConfig must happen after this line.
RayConfig::instance().initialize(internal_config);
// Start RPC server after all the task receivers are properly initialized and we have
// our assigned port from the raylet.
core_worker_server_ = std::unique_ptr<rpc::GrpcServer>(
new rpc::GrpcServer(WorkerTypeString(options_.worker_type), assigned_port));
core_worker_server_->RegisterService(grpc_service_);
core_worker_server_->Run();
// Tell the raylet the port that we are listening on.
RAY_CHECK_OK(local_raylet_client_->AnnounceWorkerPort(core_worker_server_->GetPort()));
// Set our own address.
RAY_CHECK(!local_raylet_id.IsNil());
rpc_address_.set_ip_address(options_.node_ip_address);
rpc_address_.set_port(core_worker_server_.GetPort());
rpc_address_.set_port(core_worker_server_->GetPort());
rpc_address_.set_raylet_id(local_raylet_id.Binary());
rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary());
RAY_LOG(INFO) << "Initializing worker at address: " << rpc_address_.ip_address() << ":"
+1 -1
View File
@@ -946,7 +946,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler {
boost::asio::steady_timer internal_timer_;
/// RPC server used to receive tasks to execute.
rpc::GrpcServer core_worker_server_;
std::unique_ptr<rpc::GrpcServer> core_worker_server_;
/// Address of our RPC server.
rpc::Address rpc_address_;
@@ -185,6 +185,8 @@ class CoreWorkerTest : public RedisServiceManagerForTest {
.append(" --node_ip_address=" + node_ip_address)
.append(" --redis_address=" + redis_address)
.append(" --redis_port=6379")
.append(" --min-worker-port=0")
.append(" --max-worker-port=0")
.append(" --num_initial_workers=1")
.append(" --maximum_startup_concurrency=10")
.append(" --static_resource_list=" + resource)
+9
View File
@@ -33,6 +33,8 @@ enum MessageType:int {
// Send a reply confirming the successful registration of a worker or driver.
// This is sent from the raylet to a worker or driver.
RegisterClientReply,
// Send the worker's gRPC port to the raylet.
AnnounceWorkerPort,
// Notify the raylet that this client is disconnecting unexpectedly.
// This is sent from a worker to a raylet.
DisconnectClient,
@@ -160,12 +162,19 @@ table RegisterClientRequest {
table RegisterClientReply {
// GCS ClientID of the local node manager.
raylet_id: string;
// Port that this worker should listen on.
port: int;
// Keys for internal config options.
internal_config_keys: [string];
// Values for internal config options corresponding to keys above.
internal_config_values: [string];
}
table AnnounceWorkerPort {
// Port that this worker is listening on.
port: int;
}
table RegisterNodeManagerRequest {
// GCS ClientID of the connecting node manager.
client_id: string;
+8
View File
@@ -31,6 +31,10 @@ DEFINE_int32(node_manager_port, -1, "The port of node manager.");
DEFINE_string(node_ip_address, "", "The ip address of this node.");
DEFINE_string(redis_address, "", "The ip address of redis server.");
DEFINE_int32(redis_port, -1, "The port of redis server.");
DEFINE_int32(min_worker_port, 0,
"The lowest port that workers' gRPC servers will bind on.");
DEFINE_int32(max_worker_port, 0,
"The highest port that workers' gRPC servers will bind on.");
DEFINE_int32(num_initial_workers, 0, "Number of initial workers.");
DEFINE_int32(maximum_startup_concurrency, 1, "Maximum startup concurrency");
DEFINE_string(static_resource_list, "", "The static resource list of this node.");
@@ -62,6 +66,8 @@ int main(int argc, char *argv[]) {
const std::string node_ip_address = FLAGS_node_ip_address;
const std::string redis_address = FLAGS_redis_address;
const int redis_port = static_cast<int>(FLAGS_redis_port);
const int min_worker_port = static_cast<int>(FLAGS_min_worker_port);
const int max_worker_port = static_cast<int>(FLAGS_max_worker_port);
const int num_initial_workers = static_cast<int>(FLAGS_num_initial_workers);
const int maximum_startup_concurrency =
static_cast<int>(FLAGS_maximum_startup_concurrency);
@@ -121,6 +127,8 @@ int main(int argc, char *argv[]) {
node_manager_config.node_manager_port = node_manager_port;
node_manager_config.num_initial_workers = num_initial_workers;
node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency;
node_manager_config.min_worker_port = min_worker_port;
node_manager_config.max_worker_port = max_worker_port;
if (!python_worker_command.empty()) {
node_manager_config.worker_commands.emplace(
+60 -30
View File
@@ -141,7 +141,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
local_available_resources_(config.resource_config),
worker_pool_(
io_service, config.num_initial_workers, config.maximum_startup_concurrency,
gcs_client_, config.worker_commands, config.raylet_config,
config.min_worker_port, config.max_worker_port, gcs_client_,
config.worker_commands, config.raylet_config,
/*starting_worker_timeout_callback=*/
[this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }),
scheduling_policy_(local_queues_),
@@ -398,8 +399,8 @@ void NodeManager::Heartbeat() {
}
void NodeManager::DoLocalGC() {
auto all_workers = worker_pool_.GetAllWorkers();
for (const auto &driver : worker_pool_.GetAllDrivers()) {
auto all_workers = worker_pool_.GetAllRegisteredWorkers();
for (const auto &driver : worker_pool_.GetAllRegisteredDrivers()) {
all_workers.push_back(driver);
}
RAY_LOG(WARNING) << "Sending local GC request to " << all_workers.size() << " workers.";
@@ -995,6 +996,9 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
case protocol::MessageType::RegisterClientRequest: {
ProcessRegisterClientRequestMessage(client, message_data);
} break;
case protocol::MessageType::AnnounceWorkerPort: {
ProcessAnnounceWorkerPortMessage(client, message_data);
} break;
case protocol::MessageType::TaskDone: {
HandleWorkerAvailable(client);
} break;
@@ -1086,37 +1090,21 @@ void NodeManager::ProcessClientMessage(const std::shared_ptr<ClientConnection> &
void NodeManager::ProcessRegisterClientRequestMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
client->Register();
flatbuffers::FlatBufferBuilder fbb;
std::vector<std::string> internal_config_keys;
std::vector<std::string> internal_config_values;
for (auto kv : initial_config_.raylet_config) {
internal_config_keys.push_back(kv.first);
internal_config_values.push_back(kv.second);
}
auto reply = ray::protocol::CreateRegisterClientReply(
fbb, to_flatbuf(fbb, self_node_id_),
string_vec_to_flatbuf(fbb, internal_config_keys),
string_vec_to_flatbuf(fbb, internal_config_values));
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply), fbb.GetSize(),
fbb.GetBufferPointer(), [this, client](const ray::Status &status) {
if (!status.ok()) {
ProcessDisconnectClientMessage(client);
}
});
auto message = flatbuffers::GetRoot<protocol::RegisterClientRequest>(message_data);
Language language = static_cast<Language>(message->language());
WorkerID worker_id = from_flatbuf<WorkerID>(*message->worker_id());
pid_t pid = message->worker_pid();
std::string worker_ip_address = string_from_flatbuf(*message->ip_address());
auto worker = std::make_shared<Worker>(worker_id, language, worker_ip_address,
message->port(), client, client_call_manager_);
auto worker = std::make_shared<Worker>(worker_id, language, worker_ip_address, client,
client_call_manager_);
int assigned_port;
if (message->is_worker()) {
// Register the new worker.
if (worker_pool_.RegisterWorker(worker, pid).ok()) {
HandleWorkerAvailable(worker->Connection());
if (!worker_pool_.RegisterWorker(worker, pid, &assigned_port).ok()) {
// Return -1 to signal to the worker that registration failed.
assigned_port = -1;
}
} else {
// Register the new driver.
@@ -1127,14 +1115,56 @@ void NodeManager::ProcessRegisterClientRequestMessage(
const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id);
worker->AssignTaskId(driver_task_id);
worker->AssignJobId(job_id);
Status status = worker_pool_.RegisterDriver(worker);
Status status = worker_pool_.RegisterDriver(worker, &assigned_port);
if (status.ok()) {
local_queues_.AddDriverTaskId(driver_task_id);
auto job_data_ptr = gcs::CreateJobTableData(
job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr));
} else {
// Return -1 to signal to the worker that registration failed.
assigned_port = -1;
}
}
flatbuffers::FlatBufferBuilder fbb;
std::vector<std::string> internal_config_keys;
std::vector<std::string> internal_config_values;
for (auto kv : initial_config_.raylet_config) {
internal_config_keys.push_back(kv.first);
internal_config_values.push_back(kv.second);
}
auto reply = ray::protocol::CreateRegisterClientReply(
fbb, to_flatbuf(fbb, self_node_id_), assigned_port,
string_vec_to_flatbuf(fbb, internal_config_keys),
string_vec_to_flatbuf(fbb, internal_config_values));
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply), fbb.GetSize(),
fbb.GetBufferPointer(), [this, client](const ray::Status &status) {
if (!status.ok()) {
ProcessDisconnectClientMessage(client);
}
});
}
void NodeManager::ProcessAnnounceWorkerPortMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data) {
bool is_worker = true;
std::shared_ptr<Worker> worker = worker_pool_.GetRegisteredWorker(client);
if (worker == nullptr) {
is_worker = false;
worker = worker_pool_.GetRegisteredDriver(client);
}
RAY_CHECK(worker != nullptr) << "No worker exists for CoreWorker with client: "
<< client->DebugString();
auto message = flatbuffers::GetRoot<protocol::AnnounceWorkerPort>(message_data);
int port = message->port();
worker->Connect(port);
if (is_worker) {
HandleWorkerAvailable(worker->Connection());
}
}
void NodeManager::HandleDisconnectedActor(const ActorID &actor_id, bool was_local,
@@ -2662,7 +2692,7 @@ std::shared_ptr<ActorTableData> NodeManager::CreateActorTableDataFromCreationTas
}
void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
RAY_LOG(INFO) << "Finishing assigned actor task";
RAY_LOG(DEBUG) << "Finishing assigned actor task";
ActorID actor_id;
TaskID caller_id;
const TaskSpecification task_spec = task.GetTaskSpecification();
@@ -3520,9 +3550,9 @@ void NodeManager::HandleGetNodeStats(const rpc::GetNodeStatsRequest &node_stats_
// and return the information from HandleNodesStatsRequest. The caller of
// HandleGetNodeStats should set a timeout so that the rpc finishes even if not all
// workers have replied.
auto all_workers = worker_pool_.GetAllWorkers();
auto all_workers = worker_pool_.GetAllRegisteredWorkers();
absl::flat_hash_set<WorkerID> driver_ids;
for (auto driver : worker_pool_.GetAllDrivers()) {
for (auto driver : worker_pool_.GetAllRegisteredDrivers()) {
all_workers.push_back(driver);
driver_ids.insert(driver->WorkerId());
}
+14
View File
@@ -58,6 +58,12 @@ struct NodeManagerConfig {
/// The port to use for listening to incoming connections. If this is 0 then
/// the node manager will choose its own port.
int node_manager_port;
/// The lowest port number that workers started will bind on.
/// If this is set to 0, workers will bind on random ports.
int min_worker_port;
/// The highest port number that workers started will bind on.
/// If this is not set to 0, min_worker_port must also not be set to 0.
int max_worker_port;
/// The initial number of workers to create.
int num_initial_workers;
/// The maximum number of workers that can be started concurrently by a
@@ -453,6 +459,14 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
void ProcessRegisterClientRequestMessage(
const std::shared_ptr<ClientConnection> &client, const uint8_t *message_data);
/// Process client message of AnnounceWorkerPort
///
/// \param client The client that sent the message.
/// \param message_data A pointer to the message data.
/// \return Void.
void ProcessAnnounceWorkerPortMessage(const std::shared_ptr<ClientConnection> &client,
const uint8_t *message_data);
/// Handle the case that a worker is available.
///
/// \param client The connection for the worker.
+13 -5
View File
@@ -163,11 +163,11 @@ raylet::RayletClient::RayletClient(
raylet::RayletClient::RayletClient(
boost::asio::io_service &io_service,
std::shared_ptr<rpc::NodeManagerWorkerClient> grpc_client,
std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker,
const JobID &job_id, const Language &language, ClientID *raylet_id,
std::unordered_map<std::string, std::string> *internal_config,
const std::string &ip_address, int port)
const JobID &job_id, const Language &language, const std::string &ip_address,
ClientID *raylet_id, int *port,
std::unordered_map<std::string, std::string> *internal_config)
: grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) {
// For C++14, we could use std::make_unique
conn_ = std::unique_ptr<raylet::RayletConnection>(
@@ -176,7 +176,7 @@ raylet::RayletClient::RayletClient(
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateRegisterClientRequest(
fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id),
language, fbb.CreateString(ip_address), port);
language, fbb.CreateString(ip_address));
fbb.Finish(message);
// Register the process ID with the raylet.
// NOTE(swang): If raylet exits and we are registered as a worker, we will get killed.
@@ -186,6 +186,7 @@ raylet::RayletClient::RayletClient(
RAY_CHECK_OK_PREPEND(status, "[RayletClient] Unable to register worker with raylet.");
auto reply_message = flatbuffers::GetRoot<protocol::RegisterClientReply>(reply.get());
*raylet_id = ClientID::FromBinary(reply_message->raylet_id()->str());
*port = reply_message->port();
RAY_CHECK(internal_config);
auto keys = reply_message->internal_config_keys();
@@ -196,6 +197,13 @@ raylet::RayletClient::RayletClient(
}
}
Status raylet::RayletClient::AnnounceWorkerPort(int port) {
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateAnnounceWorkerPort(fbb, port);
fbb.Finish(message);
return conn_->WriteMessage(MessageType::AnnounceWorkerPort, &fbb);
}
Status raylet::RayletClient::SubmitTask(const TaskSpecification &task_spec) {
flatbuffers::FlatBufferBuilder fbb;
auto message =
+11 -6
View File
@@ -153,19 +153,18 @@ class RayletClient : public PinObjectsInterface,
/// additional message will be sent to register as one.
/// \param job_id The ID of the driver. This is non-nil if the client is a driver.
/// \param language Language of the worker.
/// \param ip_address The IP address of the worker.
/// \param raylet_id This will be populated with the local raylet's ClientID.
/// \param internal_config This will be populated with internal config parameters
/// provided by the raylet.
/// \param ip_address The IP address of the worker.
/// \param port The port that the worker will listen on for gRPC requests, if
/// any.
/// \param port The port that the worker should listen on for gRPC requests. If
/// 0, the worker should choose a random port.
RayletClient(boost::asio::io_service &io_service,
std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
ClientID *raylet_id,
std::unordered_map<std::string, std::string> *internal_config,
const std::string &ip_address, int port = -1);
const std::string &ip_address, ClientID *raylet_id, int *port,
std::unordered_map<std::string, std::string> *internal_config);
/// Connect to the raylet via grpc only.
///
@@ -174,6 +173,12 @@ class RayletClient : public PinObjectsInterface,
ray::Status Disconnect() { return conn_->Disconnect(); };
/// Tell the raylet which port this worker's gRPC server is listening on.
///
/// \param The port.
/// \return ray::Status.
Status AnnounceWorkerPort(int port);
/// Submit a task using the raylet code path.
///
/// \param The task specification.
+22 -12
View File
@@ -27,26 +27,19 @@ namespace raylet {
/// A constructor responsible for initializing the state of a worker.
Worker::Worker(const WorkerID &worker_id, const Language &language,
const std::string &ip_address, int port,
const std::string &ip_address,
std::shared_ptr<ClientConnection> connection,
rpc::ClientCallManager &client_call_manager)
: worker_id_(worker_id),
language_(language),
ip_address_(ip_address),
port_(port),
assigned_port_(-1),
port_(-1),
connection_(connection),
dead_(false),
blocked_(false),
client_call_manager_(client_call_manager),
is_detached_actor_(false) {
if (port_ > 0) {
rpc::Address addr;
addr.set_ip_address(ip_address_);
addr.set_port(port_);
rpc_client_ = std::unique_ptr<rpc::CoreWorkerClient>(
new rpc::CoreWorkerClient(addr, client_call_manager_));
}
}
is_detached_actor_(false) {}
void Worker::MarkDead() { dead_ = true; }
@@ -71,7 +64,24 @@ Language Worker::GetLanguage() const { return language_; }
const std::string Worker::IpAddress() const { return ip_address_; }
int Worker::Port() const { return port_; }
int Worker::Port() const {
RAY_CHECK(port_ > 0);
return port_;
}
int Worker::AssignedPort() const { return assigned_port_; }
void Worker::SetAssignedPort(int port) { assigned_port_ = port; };
void Worker::Connect(int port) {
RAY_CHECK(port > 0);
port_ = port;
rpc::Address addr;
addr.set_ip_address(ip_address_);
addr.set_port(port_);
rpc_client_ = std::unique_ptr<rpc::CoreWorkerClient>(
new rpc::CoreWorkerClient(addr, client_call_manager_));
}
void Worker::AssignTaskId(const TaskID &task_id) { assigned_task_id_ = task_id; }
+16 -4
View File
@@ -39,8 +39,7 @@ class Worker {
/// A constructor that initializes a worker object.
/// NOTE: You MUST manually set the worker process.
Worker(const WorkerID &worker_id, const Language &language,
const std::string &ip_address, int port,
std::shared_ptr<ClientConnection> connection,
const std::string &ip_address, std::shared_ptr<ClientConnection> connection,
rpc::ClientCallManager &client_call_manager);
/// A destructor responsible for freeing all worker state.
~Worker() {}
@@ -56,7 +55,11 @@ class Worker {
void SetProcess(Process proc);
Language GetLanguage() const;
const std::string IpAddress() const;
/// Connect this worker's gRPC client.
void Connect(int port);
int Port() const;
int AssignedPort() const;
void SetAssignedPort(int port);
void AssignTaskId(const TaskID &task_id);
const TaskID &GetAssignedTaskId() const;
bool AddBlockedTaskId(const TaskID &task_id);
@@ -124,7 +127,12 @@ class Worker {
void SetAssignedTask(Task &assigned_task) { assigned_task_ = assigned_task; };
rpc::CoreWorkerClient *rpc_client() { return rpc_client_.get(); }
bool IsRegistered() { return rpc_client_ != nullptr; }
rpc::CoreWorkerClient *rpc_client() {
RAY_CHECK(IsRegistered());
return rpc_client_.get();
}
private:
/// The worker's ID.
@@ -135,8 +143,12 @@ class Worker {
Language language_;
/// IP address of this worker.
std::string ip_address_;
/// Port assigned to this worker by the raylet. If this is 0, the actual
/// port the worker listens (port_) on will be a random one. This is required
/// because a worker could crash before announcing its port, in which case
/// we still need to be able to mark that port as free.
int assigned_port_;
/// Port that this worker listens on.
/// If port <= 0, this indicates that the worker will not listen to a port.
int port_;
/// Connection state of a worker.
std::shared_ptr<ClientConnection> connection_;
+53 -10
View File
@@ -55,8 +55,8 @@ namespace raylet {
/// A constructor that initializes a worker pool with num_workers workers for
/// each language.
WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
int maximum_startup_concurrency,
std::shared_ptr<gcs::GcsClient> gcs_client,
int maximum_startup_concurrency, int min_worker_port,
int max_worker_port, std::shared_ptr<gcs::GcsClient> gcs_client,
const WorkerCommandMap &worker_commands,
const std::unordered_map<std::string, std::string> &raylet_config,
std::function<void()> starting_worker_timeout_callback)
@@ -102,6 +102,18 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
state.worker_command = entry.second;
RAY_CHECK(!state.worker_command.empty()) << "Worker command must not be empty.";
}
// Initialize free ports list with all ports in the specified range.
if (min_worker_port != 0) {
if (max_worker_port == 0) {
max_worker_port = 65535; // Maximum valid port number.
}
RAY_CHECK(min_worker_port > 0 && min_worker_port <= 65535);
RAY_CHECK(max_worker_port >= min_worker_port && max_worker_port <= 65535);
free_ports_ = std::unique_ptr<std::queue<int>>(new std::queue<int>());
for (int port = min_worker_port; port <= max_worker_port; port++) {
free_ports_->push(port);
}
}
Start(num_workers);
}
@@ -288,15 +300,38 @@ Process WorkerPool::StartProcess(const std::vector<std::string> &worker_command_
return child;
}
Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid) {
const auto port = worker->Port();
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port;
Status WorkerPool::GetNextFreePort(int *port) {
if (free_ports_) {
if (free_ports_->empty()) {
return Status::Invalid(
"Ran out of ports to allocate to workers. Please specify a wider port range.");
}
*port = free_ports_->front();
free_ports_->pop();
} else {
*port = 0;
}
return Status::OK();
}
void WorkerPool::MarkPortAsFree(int port) {
if (free_ports_) {
RAY_CHECK(port != 0) << "";
free_ports_->push(port);
}
}
Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid,
int *port) {
auto &state = GetStateForLanguage(worker->GetLanguage());
auto it = state.starting_worker_processes.find(Process::FromPid(pid));
if (it == state.starting_worker_processes.end()) {
RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid;
return Status::Invalid("Unknown worker");
}
RAY_RETURN_NOT_OK(GetNextFreePort(port));
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << *port;
worker->SetAssignedPort(*port);
worker->SetProcess(it->first);
it->second--;
if (it->second == 0) {
@@ -307,8 +342,10 @@ Status WorkerPool::RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t p
return Status::OK();
}
Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver) {
Status WorkerPool::RegisterDriver(const std::shared_ptr<Worker> &driver, int *port) {
RAY_CHECK(!driver->GetAssignedTaskId().IsNil());
RAY_RETURN_NOT_OK(GetNextFreePort(port));
driver->SetAssignedPort(*port);
auto &state = GetStateForLanguage(driver->GetLanguage());
state.registered_drivers.insert(std::move(driver));
return Status::OK();
@@ -421,6 +458,7 @@ bool WorkerPool::DisconnectWorker(const std::shared_ptr<Worker> &worker) {
0, {{stats::LanguageKey, Language_Name(worker->GetLanguage())},
{stats::WorkerPidKey, std::to_string(worker->GetProcess().GetId())}});
MarkPortAsFree(worker->AssignedPort());
return RemoveWorker(state.idle, worker);
}
@@ -430,6 +468,7 @@ void WorkerPool::DisconnectDriver(const std::shared_ptr<Worker> &driver) {
stats::CurrentDriver().Record(
0, {{stats::LanguageKey, Language_Name(driver->GetLanguage())},
{stats::WorkerPidKey, std::to_string(driver->GetProcess().GetId())}});
MarkPortAsFree(driver->AssignedPort());
}
inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &language) {
@@ -453,24 +492,28 @@ std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
return workers;
}
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllWorkers() const {
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllRegisteredWorkers() const {
std::vector<std::shared_ptr<Worker>> workers;
for (const auto &entry : states_by_lang_) {
for (const auto &worker : entry.second.registered_workers) {
workers.push_back(worker);
if (worker->IsRegistered()) {
workers.push_back(worker);
}
}
}
return workers;
}
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllDrivers() const {
const std::vector<std::shared_ptr<Worker>> WorkerPool::GetAllRegisteredDrivers() const {
std::vector<std::shared_ptr<Worker>> drivers;
for (const auto &entry : states_by_lang_) {
for (const auto &driver : entry.second.registered_drivers) {
drivers.push_back(driver);
if (driver->IsRegistered()) {
drivers.push_back(driver);
}
}
}
+32 -9
View File
@@ -18,6 +18,7 @@
#include <inttypes.h>
#include <boost/asio/io_service.hpp>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -54,13 +55,18 @@ class WorkerPool {
/// \param maximum_startup_concurrency The maximum number of worker processes
/// that can be started in parallel (typically this should be set to the number of CPU
/// resources on the machine).
/// \param min_worker_port The lowest port number that workers started will bind on.
/// If this is set to 0, workers will bind on random ports.
/// \param max_worker_port The highest port number that workers started will bind on.
/// If this is not set to 0, min_worker_port must also not be set to 0.
/// \param worker_commands The commands used to start the worker process, grouped by
/// language.
/// \param raylet_config The raylet config list of this node.
/// \param starting_worker_timeout_callback The callback that will be triggered once
/// it times out to start a worker.
WorkerPool(boost::asio::io_service &io_service, int num_workers,
int maximum_startup_concurrency, std::shared_ptr<gcs::GcsClient> gcs_client,
int maximum_startup_concurrency, int min_worker_port, int max_worker_port,
std::shared_ptr<gcs::GcsClient> gcs_client,
const WorkerCommandMap &worker_commands,
const std::unordered_map<std::string, std::string> &raylet_config,
std::function<void()> starting_worker_timeout_callback);
@@ -71,15 +77,20 @@ class WorkerPool {
/// Register a new worker. The Worker should be added by the caller to the
/// pool after it becomes idle (e.g., requests a work assignment).
///
/// \param The Worker to be registered.
/// \param[in] worker The worker to be registered.
/// \param[in] pid The PID of the worker.
/// \param[out] port The port that this worker's gRPC server should listen on.
/// Returns 0 if the worker should bind on a random port.
/// \return If the registration is successful.
Status RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid);
Status RegisterWorker(const std::shared_ptr<Worker> &worker, pid_t pid, int *port);
/// Register a new driver.
///
/// \param The driver to be registered.
/// \param[in] worker The driver to be registered.
/// \param[out] port The port that this driver's gRPC server should listen on.
/// Returns 0 if the driver should bind on a random port.
/// \return If the registration is successful.
Status RegisterDriver(const std::shared_ptr<Worker> &worker);
Status RegisterDriver(const std::shared_ptr<Worker> &worker, int *port);
/// Get the client connection's registered worker.
///
@@ -135,15 +146,15 @@ class WorkerPool {
std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForJob(
const JobID &job_id) const;
/// Get all the workers.
/// Get all the registered workers.
///
/// \return A list containing all the workers.
const std::vector<std::shared_ptr<Worker>> GetAllWorkers() const;
const std::vector<std::shared_ptr<Worker>> GetAllRegisteredWorkers() const;
/// Get all the drivers.
/// Get all the registered drivers.
///
/// \return A list containing all the drivers.
const std::vector<std::shared_ptr<Worker>> GetAllDrivers() const;
const std::vector<std::shared_ptr<Worker>> GetAllRegisteredDrivers() const;
/// Whether there is a pending worker for the given task.
/// Note that, this is only used for actor creation task with dynamic options.
@@ -248,10 +259,22 @@ class WorkerPool {
/// think there are unregistered workers, and won't start new workers.
void MonitorStartingWorkerProcess(const Process &proc, const Language &language);
/// Get the next unallocated port in the free ports list. If a port range isn't
/// configured, returns 0.
/// \param[out] port The next available port.
Status GetNextFreePort(int *port);
/// Mark this port as free to be used by another worker.
/// \param[in] port The port to mark as free.
void MarkPortAsFree(int port);
/// For Process class for managing subprocesses (e.g. reaping zombies).
boost::asio::io_service *io_service_;
/// The maximum number of worker processes that can be started concurrently.
int maximum_startup_concurrency_;
/// Keeps track of unused ports that newly-created workers can bind on.
/// If null, workers will not be passed ports and will choose them randomly.
std::unique_ptr<std::queue<int>> free_ports_;
/// A client connection to the GCS.
std::shared_ptr<gcs::GcsClient> gcs_client_;
/// The raylet config list of this node.
+45 -45
View File
@@ -40,8 +40,8 @@ class WorkerPoolMock : public WorkerPool {
explicit WorkerPoolMock(boost::asio::io_service &io_service,
const WorkerCommandMap &worker_commands)
: WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, nullptr, worker_commands,
{}, []() {}),
: WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr,
worker_commands, {}, []() {}),
last_worker_process_() {
states_by_lang_[ray::Language::JAVA].num_workers_per_process =
NUM_WORKERS_PER_PROCESS_JAVA;
@@ -96,10 +96,9 @@ class WorkerPoolMock : public WorkerPool {
class WorkerPoolTest : public ::testing::Test {
public:
WorkerPoolTest()
: worker_pool_(io_service_),
error_message_type_(1),
client_call_manager_(io_service_) {}
WorkerPoolTest() : error_message_type_(1), client_call_manager_(io_service_) {
worker_pool_ = std::unique_ptr<WorkerPoolMock>(new WorkerPoolMock(io_service_));
}
std::shared_ptr<Worker> CreateWorker(Process proc,
const Language &language = Language::PYTHON) {
@@ -115,7 +114,7 @@ class WorkerPoolTest : public ::testing::Test {
ClientConnection::Create(client_handler, message_handler, std::move(socket),
"worker", {}, error_message_type_);
std::shared_ptr<Worker> worker = std::make_shared<Worker>(
WorkerID::FromRandom(), language, "127.0.0.1", -1, client, client_call_manager_);
WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_);
if (!proc.IsNull()) {
worker->SetProcess(proc);
}
@@ -123,8 +122,8 @@ class WorkerPoolTest : public ::testing::Test {
}
void SetWorkerCommands(const WorkerCommandMap &worker_commands) {
WorkerPoolMock worker_pool(io_service_, worker_commands);
this->worker_pool_ = std::move(worker_pool);
worker_pool_ =
std::unique_ptr<WorkerPoolMock>(new WorkerPoolMock(io_service_, worker_commands));
}
void TestStartupWorkerProcessCount(Language language, int num_workers_per_process,
@@ -136,28 +135,28 @@ class WorkerPoolTest : public ::testing::Test {
static_cast<int>(desired_initial_worker_process_count));
Process last_started_worker_process;
for (int i = 0; i < desired_initial_worker_process_count; i++) {
worker_pool_.StartWorkerProcess(language);
ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() <=
worker_pool_->StartWorkerProcess(language);
ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <=
expected_worker_process_count);
Process prev = worker_pool_.LastStartedWorkerProcess();
Process prev = worker_pool_->LastStartedWorkerProcess();
if (!std::equal_to<Process>()(last_started_worker_process, prev)) {
last_started_worker_process = prev;
const auto &real_command =
worker_pool_.GetWorkerCommand(last_started_worker_process);
worker_pool_->GetWorkerCommand(last_started_worker_process);
ASSERT_EQ(real_command, expected_worker_command);
} else {
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(),
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(),
expected_worker_process_count);
ASSERT_TRUE(i >= expected_worker_process_count);
}
}
// Check number of starting workers
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), expected_worker_process_count);
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), expected_worker_process_count);
}
protected:
boost::asio::io_service io_service_;
WorkerPoolMock worker_pool_;
std::unique_ptr<WorkerPoolMock> worker_pool_;
int64_t error_message_type_;
rpc::ClientCallManager client_call_manager_;
@@ -202,7 +201,7 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) {
}
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
Process proc = worker_pool_.StartWorkerProcess(Language::JAVA);
Process proc = worker_pool_->StartWorkerProcess(Language::JAVA);
std::vector<std::shared_ptr<Worker>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
workers.push_back(CreateWorker(Process(), Language::JAVA));
@@ -210,19 +209,20 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
for (const auto &worker : workers) {
// Check that there's still a starting worker process
// before all workers have been registered
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 1);
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 1);
// Check that we cannot lookup the worker before it's registered.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
RAY_CHECK_OK(worker_pool_.RegisterWorker(worker, proc.GetId()));
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr);
int port;
RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), &port));
// Check that we can lookup the worker after it's registered.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), worker);
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker);
}
// Check that there's no starting worker process
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), 0);
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 0);
for (const auto &worker : workers) {
worker_pool_.DisconnectWorker(worker);
worker_pool_->DisconnectWorker(worker);
// Check that we cannot lookup the worker after it's disconnected.
ASSERT_EQ(worker_pool_.GetRegisteredWorker(worker->Connection()), nullptr);
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr);
}
}
@@ -239,21 +239,21 @@ TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) {
}
TEST_F(WorkerPoolTest, InitialWorkerProcessCount) {
worker_pool_.Start(1);
worker_pool_->Start(1);
// Here we try to start only 1 worker for each worker language. But since each Java
// worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here,
// it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for
// each worker language.
ASSERT_NE(worker_pool_.NumWorkersStarting(), 1 * LANGUAGES.size());
ASSERT_EQ(worker_pool_.NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA);
ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), LANGUAGES.size());
ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size());
ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA);
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size());
}
TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
// Try to pop a worker from the empty pool and make sure we don't get one.
std::shared_ptr<Worker> popped_worker;
const auto task_spec = ExampleTaskSpec();
popped_worker = worker_pool_.PopWorker(task_spec);
popped_worker = worker_pool_->PopWorker(task_spec);
ASSERT_EQ(popped_worker, nullptr);
// Create some workers.
@@ -262,17 +262,17 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
workers.insert(CreateWorker(Process::CreateNewDummy()));
// Add the workers to the pool.
for (auto &worker : workers) {
worker_pool_.PushWorker(worker);
worker_pool_->PushWorker(worker);
}
// Pop two workers and make sure they're one of the workers we created.
popped_worker = worker_pool_.PopWorker(task_spec);
popped_worker = worker_pool_->PopWorker(task_spec);
ASSERT_NE(popped_worker, nullptr);
ASSERT_TRUE(workers.count(popped_worker) > 0);
popped_worker = worker_pool_.PopWorker(task_spec);
popped_worker = worker_pool_->PopWorker(task_spec);
ASSERT_NE(popped_worker, nullptr);
ASSERT_TRUE(workers.count(popped_worker) > 0);
popped_worker = worker_pool_.PopWorker(task_spec);
popped_worker = worker_pool_->PopWorker(task_spec);
ASSERT_EQ(popped_worker, nullptr);
}
@@ -280,21 +280,21 @@ TEST_F(WorkerPoolTest, PopActorWorker) {
// Create a worker.
auto worker = CreateWorker(Process::CreateNewDummy());
// Add the worker to the pool.
worker_pool_.PushWorker(worker);
worker_pool_->PushWorker(worker);
// Assign an actor ID to the worker.
const auto task_spec = ExampleTaskSpec();
auto actor = worker_pool_.PopWorker(task_spec);
auto actor = worker_pool_->PopWorker(task_spec);
const auto job_id = JobID::FromInt(1);
auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
actor->AssignActorId(actor_id);
worker_pool_.PushWorker(actor);
worker_pool_->PushWorker(actor);
// Check that there are no more non-actor workers.
ASSERT_EQ(worker_pool_.PopWorker(task_spec), nullptr);
ASSERT_EQ(worker_pool_->PopWorker(task_spec), nullptr);
// Check that we can pop the actor worker.
const auto actor_task_spec = ExampleTaskSpec(actor_id);
actor = worker_pool_.PopWorker(actor_task_spec);
actor = worker_pool_->PopWorker(actor_task_spec);
ASSERT_EQ(actor, worker);
ASSERT_EQ(actor->GetActorId(), actor_id);
}
@@ -302,19 +302,19 @@ TEST_F(WorkerPoolTest, PopActorWorker) {
TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) {
// Create a Python Worker, and add it to the pool
auto py_worker = CreateWorker(Process::CreateNewDummy(), Language::PYTHON);
worker_pool_.PushWorker(py_worker);
worker_pool_->PushWorker(py_worker);
// Check that no worker will be popped if the given task is a Java task
const auto java_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::JAVA);
ASSERT_EQ(worker_pool_.PopWorker(java_task_spec), nullptr);
ASSERT_EQ(worker_pool_->PopWorker(java_task_spec), nullptr);
// Check that the worker can be popped if the given task is a Python task
const auto py_task_spec = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON);
ASSERT_NE(worker_pool_.PopWorker(py_task_spec), nullptr);
ASSERT_NE(worker_pool_->PopWorker(py_task_spec), nullptr);
// Create a Java Worker, and add it to the pool
auto java_worker = CreateWorker(Process::CreateNewDummy(), Language::JAVA);
worker_pool_.PushWorker(java_worker);
worker_pool_->PushWorker(java_worker);
// Check that the worker will be popped now for Java task
ASSERT_NE(worker_pool_.PopWorker(java_task_spec), nullptr);
ASSERT_NE(worker_pool_->PopWorker(java_task_spec), nullptr);
}
TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
@@ -328,9 +328,9 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
TaskSpecification task_spec = ExampleTaskSpec(
ActorID::Nil(), Language::JAVA,
ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), {"test_op_0", "test_op_1"});
worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions());
worker_pool_->StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions());
const auto real_command =
worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess());
worker_pool_->GetWorkerCommand(worker_pool_->LastStartedWorkerProcess());
ASSERT_EQ(real_command,
std::vector<std::string>(
{"test_op_0", "dummy_java_worker_command",