[Core] Multi-tenancy: Job isolation & implement per job config (except for env variables) (#9500)

This commit is contained in:
Kai Yang
2020-08-04 15:51:29 +08:00
committed by GitHub
parent 28b1f7710c
commit 27cd323ce1
35 changed files with 969 additions and 184 deletions
@@ -11,6 +11,7 @@ import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.gcs.GcsClientOptions;
import io.ray.runtime.gcs.RedisClient;
import io.ray.runtime.generated.Common.WorkerType;
import io.ray.runtime.generated.Gcs.JobConfig;
import io.ray.runtime.object.NativeObjectStore;
import io.ray.runtime.runner.RunManager;
import io.ray.runtime.task.NativeTaskExecutor;
@@ -106,6 +107,17 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.putAllWorkerEnv(rayConfig.workerEnv);
serializedJobConfig = jobConfigBuilder.build().toByteArray();
}
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
@@ -113,7 +125,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayConfig.rayletConfigParameters);
rayConfig.logDir, rayConfig.rayletConfigParameters, serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
@@ -201,7 +213,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
int workerMode, String ndoeIpAddress,
int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
String logDir, Map<String, String> rayletConfigParameters);
String logDir, Map<String, String> rayletConfigParameters, byte[] serializedJobConfig);
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
@@ -3,6 +3,7 @@ package io.ray.runtime.config;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigException;
import com.typesafe.config.ConfigFactory;
@@ -29,6 +30,8 @@ public class RayConfig {
public static final String DEFAULT_CONFIG_FILE = "ray.default.conf";
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
/**
 * Fallback number of Java workers per worker process, used when multi-tenancy is enabled and
 * {@code ray.job.num-java-workers-per-process} is not a positive number.
 */
private static final int DEFAULT_NUM_JAVA_WORKER_PER_PROCESS = 10;
private static final Random RANDOM = new Random();
private static final DateTimeFormatter DATE_TIME_FORMATTER =
@@ -90,6 +93,9 @@ public class RayConfig {
public final int numWorkersPerProcess;
public final List<String> jvmOptionsForJavaWorker;
public final Map<String, String> workerEnv;
private void validate() {
if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
@@ -141,6 +147,17 @@ public class RayConfig {
this.jobId = JobId.NIL;
}
// jvm options for java workers of this job.
jvmOptionsForJavaWorker = config.getStringList("ray.job.jvm-options");
ImmutableMap.Builder<String, String> workerEnvBuilder = ImmutableMap.builder();
Config workerEnvConfig = config.getConfig("ray.job.worker-env");
if (workerEnvConfig != null) {
for (Map.Entry<String, ConfigValue> entry : workerEnvConfig.entrySet()) {
workerEnvBuilder.put(entry.getKey(), workerEnvConfig.getString(entry.getKey()));
}
}
workerEnv = workerEnvBuilder.build();
updateSessionDir();
// Object store configurations.
objectStoreSize = config.getBytes("ray.object-store.size");
@@ -206,7 +223,22 @@ public class RayConfig {
jobResourcePath = null;
}
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
boolean enableMultiTenancy = false;
if (config.hasPath("ray.raylet.config.enable_multi_tenancy")) {
enableMultiTenancy =
Boolean.valueOf(config.getString("ray.raylet.config.enable_multi_tenancy"));
}
if (!enableMultiTenancy) {
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
} else {
final int localNumWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
if (localNumWorkersPerProcess <= 0) {
numWorkersPerProcess = DEFAULT_NUM_JAVA_WORKER_PER_PROCESS;
} else {
numWorkersPerProcess = localNumWorkersPerProcess;
}
}
// Validate config.
validate();
@@ -29,6 +29,15 @@ ray {
// executing tasks from different jobs. E.g. if it's set to '/tmp/job_resources',
// the path for job 123 will be '/tmp/job_resources/123'.
resource-path: ""
/// The number of Java workers per worker process.
num-java-workers-per-process: 10
/// The jvm options for java workers of the job.
jvm-options: []
// Environment variables to be set on worker processes.
worker-env {
// key1 : "value1"
// key2 : "value2"
}
}
// Configurations about logging.
@@ -0,0 +1,69 @@
package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Verifies that the per-job configuration set on the driver (JVM options and the number of Java
 * workers per process) is what the workers running the job's tasks and actors observe.
 *
 * <p>The configuration is injected through system properties before the class runs and removed
 * afterwards. NOTE(review): this presumably relies on {@code BaseTest} starting the runtime
 * after {@link #setupJobConfig()} has run - confirm against {@code BaseTest}.
 */
@Test(groups = {"cluster"})
public class JobConfigTest extends BaseTest {

  /** Enables multi-tenancy and sets the job-level options that the tests below assert on. */
  @BeforeClass
  public void setupJobConfig() {
    System.setProperty("ray.raylet.config.enable_multi_tenancy", "true");
    System.setProperty("ray.job.num-java-workers-per-process", "3");
    System.setProperty("ray.job.jvm-options.0", "-DX=999");
  }

  /** Clears every system property set by {@link #setupJobConfig()}. */
  @AfterClass
  public void tearDownJobConfig() {
    System.clearProperty("ray.raylet.config.enable_multi_tenancy");
    System.clearProperty("ray.job.num-java-workers-per-process");
    System.clearProperty("ray.job.jvm-options.0");
  }

  /** Remote function: returns the value of system property {@code X} in the executing JVM. */
  public static String getJvmOptions() {
    return System.getProperty("X");
  }

  /** Remote function: returns {@code numWorkersPerProcess} from the executing runtime's config. */
  public static Integer getWorkersNum() {
    return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
  }

  /** Actor exposing the same two probes as the static remote functions above. */
  public static class MyActor {

    public Integer getWorkersNum() {
      return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
    }

    public String getJvmOptions() {
      return System.getProperty("X");
    }
  }

  // TestNG's assertEquals takes (actual, expected); keeping that order makes failure
  // messages report the observed and the expected value the right way round.

  public void testJvmOptions() {
    ObjectRef<String> obj = Ray.task(JobConfigTest::getJvmOptions).remote();
    Assert.assertEquals(obj.get(), "999");
  }

  public void testNumJavaWorkerPerProcess() {
    ObjectRef<Integer> obj = Ray.task(JobConfigTest::getWorkersNum).remote();
    Assert.assertEquals((int) obj.get(), 3);
  }

  public void testInActor() {
    ActorHandle<MyActor> actor = Ray.actor(MyActor::new).remote();

    // test jvm options.
    ObjectRef<String> obj1 = actor.task(MyActor::getJvmOptions).remote();
    Assert.assertEquals(obj1.get(), "999");

    // test workers number.
    ObjectRef<Integer> obj2 = actor.task(MyActor::getWorkersNum).remote();
    Assert.assertEquals((int) obj2.get(), 3);
  }
}
@@ -0,0 +1,130 @@
package io.ray.test;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.SystemUtil;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.ProcessBuilder.Redirect;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Checks that, with multi-tenancy enabled, tasks and actors owned by different drivers never
 * run in the same worker process.
 *
 * <p>The class doubles as the driver program: {@link #main(String[])} is what every driver
 * process spawned by {@link #testMultiDrivers()} executes.
 */
@Test(groups = {"cluster"})
public class MultiDriverTest extends BaseTest {

  private static final int DRIVER_COUNT = 10;
  private static final int NORMAL_TASK_COUNT_PER_DRIVER = 100;
  private static final int ACTOR_COUNT_PER_DRIVER = 10;
  /** Prefix of the stdout line on which a driver reports the worker PIDs it used. */
  private static final String PID_LIST_PREFIX = "PID: ";

  @BeforeClass
  public void setUpClass() {
    System.setProperty("ray.raylet.config.enable_multi_tenancy", "true");
  }

  @AfterClass
  public void tearDownClass() {
    System.clearProperty("ray.raylet.config.enable_multi_tenancy");
  }

  /** Remote function: returns the PID of the process executing it. */
  static int getPid() {
    return SystemUtil.pid();
  }

  /** Actor whose only method returns the PID of the process hosting it. */
  public static class Actor {

    public int getPid() {
      return SystemUtil.pid();
    }
  }

  /**
   * Driver entry point: runs normal tasks and actor tasks, collects the distinct PIDs of the
   * processes that executed them and prints them on one line prefixed with
   * {@link #PID_LIST_PREFIX}.
   */
  public static void main(String[] args) throws IOException {
    Ray.init();

    List<ObjectRef<Integer>> pidObjectList = new ArrayList<>();
    // Submit some normal tasks and get the PIDs of workers which execute the tasks.
    for (int i = 0; i < NORMAL_TASK_COUNT_PER_DRIVER; ++i) {
      pidObjectList.add(Ray.task(MultiDriverTest::getPid).remote());
    }
    // Create some actors and get the PIDs of actors.
    for (int i = 0; i < ACTOR_COUNT_PER_DRIVER; ++i) {
      ActorHandle<Actor> actor = Ray.actor(Actor::new).remote();
      pidObjectList.add(actor.task(Actor::getPid).remote());
    }
    Set<Integer> pids = new HashSet<>();
    for (ObjectRef<Integer> object : pidObjectList) {
      pids.add(object.get());
    }

    // Write pids to stdout
    System.out.println(
        PID_LIST_PREFIX + pids.stream().map(String::valueOf).collect(Collectors.joining(",")));
  }

  public void testMultiDrivers() throws InterruptedException, IOException {
    // This test case starts some driver processes. Each driver process submits some tasks and
    // collects the PIDs of the workers used by the driver. The drivers output the PID list
    // which will be read by the test case itself. The test case will compare the PIDs used by
    // different drivers and make sure that all the PIDs don't overlap. If overlapped, it means
    // that tasks owned by different drivers were scheduled to the same worker process, that is,
    // tasks of different jobs were not correctly isolated during execution.

    List<Process> drivers = new ArrayList<>();
    for (int i = 0; i < DRIVER_COUNT; ++i) {
      drivers.add(startDriver());
    }

    // Wait for drivers to finish.
    // NOTE(review): stdout of each driver is only read after it has exited, so this assumes
    // a driver never writes more to stdout than the pipe can buffer - confirm.
    for (Process driver : drivers) {
      driver.waitFor();
      Assert.assertEquals(driver.exitValue(), 0,
          "The driver exited with code " + driver.exitValue());
    }

    // Read driver outputs and check for any PID overlap.
    Set<Integer> pids = new HashSet<>();
    for (Process driver : drivers) {
      try (BufferedReader reader = new BufferedReader(
          new InputStreamReader(driver.getInputStream()))) {
        String line;
        int previousSize = pids.size();
        while ((line = reader.readLine()) != null) {
          if (line.startsWith(PID_LIST_PREFIX)) {
            for (String pidString : line.substring(PID_LIST_PREFIX.length()).split(",")) {
              // Make sure the PIDs don't overlap.
              Assert.assertTrue(pids.add(Integer.valueOf(pidString)),
                  "Worker process with PID " + pidString + " is shared by multiple drivers.");
            }
            break;
          }
        }
        // Every driver must have contributed at least one new PID.
        int nowSize = pids.size();
        Assert.assertTrue(nowSize > previousSize,
            "A driver did not report any worker PIDs.");
      }
    }
  }

  /**
   * Starts a new JVM that runs {@link #main(String[])} as a driver, pointing it at the Redis
   * address, sockets and node manager port taken from the current runtime's config. The child's
   * stderr is inherited; its stdout stays piped so the test can read the PID list.
   */
  private Process startDriver() throws IOException {
    RayConfig rayConfig = TestUtils.getRuntime().getRayConfig();

    ProcessBuilder builder = new ProcessBuilder(
        "java",
        "-cp",
        System.getProperty("java.class.path"),
        "-Dray.redis.address=" + rayConfig.getRedisAddress(),
        "-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName,
        "-Dray.raylet.socket-name=" + rayConfig.rayletSocketName,
        "-Dray.raylet.node-manager-port=" + rayConfig.getNodeManagerPort(),
        MultiDriverTest.class.getName());
    builder.redirectError(Redirect.INHERIT);
    return builder.start();
  }
}
@@ -9,8 +9,8 @@ import org.testng.annotations.Test;
public class RayletConfigTest extends BaseTest {
private static final String RAY_CONFIG_KEY = "num_workers_per_process_java";
private static final String RAY_CONFIG_VALUE = "2";
private static final String RAY_CONFIG_KEY = "get_timeout_milliseconds";
private static final String RAY_CONFIG_VALUE = "1234";
@BeforeClass
public void beforeClass() {
+3 -1
View File
@@ -658,7 +658,8 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address, node_manager_port, raylet_ip_address,
local_mode, driver_name, stdout_file, stderr_file):
local_mode, driver_name, stdout_file, stderr_file,
serialized_job_config):
self.is_driver = is_driver
self.is_local_mode = local_mode
@@ -688,6 +689,7 @@ cdef class CoreWorker:
options.num_workers = 1
options.kill_main = kill_main_task
options.terminate_asyncio_thread = terminate_asyncio_thread
options.serialized_job_config = serialized_job_config
CCoreWorkerProcess.Initialize(options)
+2
View File
@@ -3,6 +3,7 @@ from ray.core.generated.gcs_pb2 import (
ActorTableData,
GcsNodeInfo,
JobTableData,
JobConfig,
ErrorTableData,
ErrorType,
GcsEntry,
@@ -25,6 +26,7 @@ __all__ = [
"ActorTableData",
"GcsNodeInfo",
"JobTableData",
"JobConfig",
"ErrorTableData",
"ErrorType",
"GcsEntry",
+1
View File
@@ -226,6 +226,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(c_bool() nogil) kill_main
CCoreWorkerOptions()
(void() nogil) terminate_asyncio_thread
c_string serialized_job_config
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
@staticmethod
+38
View File
@@ -0,0 +1,38 @@
import ray
class JobConfig:
    """Holds the configuration of a single job.

    Attributes:
        worker_env (dict): Environment variables to be set on worker
            processes.
        num_java_workers_per_process (int): The number of java workers per
            worker process.
        jvm_options (str[]): The jvm options for java workers of the job.
    """

    def __init__(
            self,
            worker_env=None,
            num_java_workers_per_process=10,
            jvm_options=None,
    ):
        # `None` stands for "not given": each instance then gets its own
        # fresh container instead of sharing a mutable default.
        self.worker_env = dict() if worker_env is None else worker_env
        self.num_java_workers_per_process = num_java_workers_per_process
        self.jvm_options = [] if jvm_options is None else jvm_options

    def serialize(self):
        """Return this config as the bytes of a `JobConfig` protobuf."""
        message = ray.gcs_utils.JobConfig()
        for name, value in self.worker_env.items():
            message.worker_env[name] = value
        message.num_java_workers_per_process = (
            self.num_java_workers_per_process)
        message.jvm_options.extend(self.jvm_options)
        return message.SerializeToString()
+3 -1
View File
@@ -679,7 +679,9 @@ class Node:
huge_pages=self._ray_params.huge_pages,
fate_share=self.kernel_fate_share,
socket_to_use=self.socket,
head_node=self.head)
head_node=self.head,
start_initial_python_workers_for_first_job=self._ray_params.
start_initial_python_workers_for_first_job)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
+5
View File
@@ -94,6 +94,8 @@ class RayParams:
lru_evict (bool): Enable LRU eviction if space is needed.
enable_object_reconstruction (bool): Enable plasma reconstruction on
failure.
start_initial_python_workers_for_first_job (bool): If true, start
initial Python workers for the first job on the node.
"""
def __init__(self,
@@ -136,6 +138,7 @@ class RayParams:
include_java=False,
java_worker_options=None,
load_code_from_local=False,
start_initial_python_workers_for_first_job=False,
_internal_config=None,
enable_object_reconstruction=False,
metrics_agent_port=None,
@@ -178,6 +181,8 @@ class RayParams:
self.java_worker_options = java_worker_options
self.load_code_from_local = load_code_from_local
self.metrics_agent_port = metrics_agent_port
self.start_initial_python_workers_for_first_job = (
start_initial_python_workers_for_first_job)
self._internal_config = _internal_config
self._lru_evict = lru_evict
self._enable_object_reconstruction = enable_object_reconstruction
+5 -1
View File
@@ -1271,7 +1271,8 @@ def start_raylet(redis_address,
huge_pages=False,
fate_share=None,
socket_to_use=None,
head_node=False):
head_node=False,
start_initial_python_workers_for_first_job=False):
"""Start a raylet, which is a combined local scheduler and object manager.
Args:
@@ -1403,6 +1404,9 @@ def start_raylet(redis_address,
"--session_dir={}".format(session_dir),
"--metrics-agent-port={}".format(metrics_agent_port),
]
if start_initial_python_workers_for_first_job:
command.append("--num_initial_python_workers_for_first_job={}".format(
resource_spec.num_cpus))
if config.get("plasma_store_as_thread"):
# command related to the plasma store
plasma_directory, object_store_memory = determine_plasma_store_config(
+8
View File
@@ -111,6 +111,14 @@ py_test(
deps = ["//:ray_lib"],
)
# Runs test_multi_tenancy.py together with the shared SRCS.
# Tagged "exclusive" so Bazel does not run it in parallel with other tests.
py_test(
name = "test_multi_tenancy",
size = "small",
srcs = SRCS + ["test_multi_tenancy.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_component_failures",
size = "small",
+116
View File
@@ -0,0 +1,116 @@
# coding: utf-8
import json
import sys
import grpc
import pytest
import ray
import ray.test_utils
from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
from ray.test_utils import wait_for_condition, run_string_as_driver_nonblocking
# Test that when `redis_address` and `job_config` are not set in
# `ray.init(...)`, Raylet will start `num_cpus` Python workers for the driver.
def test_initial_workers(shutdown_only):
    """Start a one-CPU node with multi-tenancy enabled and wait (up to 10
    seconds) until its raylet reports exactly one non-driver worker."""
    # `num_cpus` should be <=2 because a Travis CI machine only has 2 CPU cores
    internal_config = json.dumps({"enable_multi_tenancy": True})
    ray.init(
        num_cpus=1, include_dashboard=True, _internal_config=internal_config)

    node_info = ray.nodes()[0]
    address = "{}:{}".format(node_info["NodeManagerAddress"],
                             node_info["NodeManagerPort"])
    stub = node_manager_pb2_grpc.NodeManagerServiceStub(
        grpc.insecure_channel(address))

    def exactly_one_non_driver_worker():
        # Ask the raylet for its current worker list on every poll.
        reply = stub.GetNodeStats(node_manager_pb2.GetNodeStatsRequest())
        non_drivers = [w for w in reply.workers_stats if not w.is_driver]
        return len(non_drivers) == 1

    wait_for_condition(exactly_one_non_driver_worker, timeout=10)
# This test case starts some driver processes. Each driver process submits
# some tasks and collects the PIDs of the workers used by the driver. The
# drivers output the PID list which will be read by the test case itself. The
# test case will compare the PIDs used by different drivers and make sure that
# all the PIDs don't overlap. If overlapped, it means that tasks owned by
# different drivers were scheduled to the same worker process, that is, tasks
# of different jobs were not correctly isolated during execution.
def test_multi_drivers(shutdown_only):
    """Run several drivers against one cluster and check that no worker
    process (identified by PID) is used by more than one driver."""
    info = ray.init(
        _internal_config=json.dumps({
            "enable_multi_tenancy": True
        }))

    # Script executed by every driver. It prints one line starting with
    # "PID:" that lists the PIDs of the workers that ran its tasks and actors.
    driver_code = """
import os
import sys
import ray
ray.init(address="{}")
@ray.remote
class Actor:
    def get_pid(self):
        return os.getpid()
@ray.remote
def get_pid():
    return os.getpid()
pid_objs = []
# Submit some normal tasks and get the PIDs of workers which execute the tasks.
pid_objs = pid_objs + [get_pid.remote() for _ in range(5)]
# Create some actors and get the PIDs of actors.
actors = [Actor.remote() for _ in range(5)]
pid_objs = pid_objs + [actor.get_pid.remote() for actor in actors]
pids = set([ray.get(obj) for obj in pid_objs])
# Write pids to stdout
print("PID:" + str.join(",", [str(_) for _ in pids]))
ray.shutdown()
""".format(info["redis_address"])

    driver_count = 10
    processes = [
        run_string_as_driver_nonblocking(driver_code)
        for _ in range(driver_count)
    ]

    outputs = []
    for p in processes:
        # NOTE(review): stdout is drained completely before stderr is read,
        # which assumes a driver writes little to stderr. A `p.communicate()`
        # variant was previously disabled here - confirm why before
        # switching back to it.
        out = p.stdout.read().decode("ascii")
        err = p.stderr.read().decode("ascii")
        p.wait()
        if p.returncode != 0:
            print("Driver with PID {} returned error code {}".format(
                p.pid, p.returncode))
            print("STDOUT:\n{}".format(out))
            print("STDERR:\n{}".format(err))
        outputs.append((p, out))

    all_worker_pids = set()
    for p, out in outputs:
        assert p.returncode == 0
        reported_pids = False
        for line in out.splitlines():
            if line.startswith("PID:"):
                worker_pids = [int(_) for _ in line.split(":")[1].split(",")]
                assert len(worker_pids) > 0
                for worker_pid in worker_pids:
                    assert worker_pid not in all_worker_pids, (
                        ("Worker process with PID {} is shared" +
                         " by multiple drivers.").format(worker_pid))
                    all_worker_pids.add(worker_pid)
                reported_pids = True
        # Without this, a driver that printed no PID line would pass silently.
        assert reported_pids, (
            "Driver with PID {} did not report any worker PIDs.".format(
                p.pid))
if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))
+13 -3
View File
@@ -20,6 +20,7 @@ import ray.cloudpickle as pickle
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
import ray.node
import ray.job_config
import ray.parameter
import ray.ray_constants as ray_constants
import ray.remote_function
@@ -486,6 +487,7 @@ def init(address=None,
dashboard_host="localhost",
dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT,
job_id=None,
job_config=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
@@ -588,6 +590,7 @@ def init(address=None,
dashboard_port: The port to bind the dashboard server to. Defaults to
8265.
job_id: The ID of this job.
job_config (ray.job_config.JobConfig): The job configuration.
configure_logging: True (default) if configuration of logging is
allowed here. Otherwise, the user may want to configure it
separately.
@@ -713,6 +716,7 @@ def init(address=None,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local,
java_worker_options=java_worker_options,
start_initial_python_workers_for_first_job=True,
_internal_config=_internal_config,
lru_evict=lru_evict,
enable_object_reconstruction=enable_object_reconstruction)
@@ -803,7 +807,8 @@ def init(address=None,
log_to_driver=log_to_driver,
worker=global_worker,
driver_object_store_memory=driver_object_store_memory,
job_id=job_id)
job_id=job_id,
job_config=job_config)
for hook in _post_init_hooks:
hook()
@@ -1147,7 +1152,8 @@ def connect(node,
log_to_driver=False,
worker=global_worker,
driver_object_store_memory=None,
job_id=None):
job_id=None,
job_config=None):
"""Connect this worker to the raylet, to Plasma, and to Redis.
Args:
@@ -1160,6 +1166,7 @@ def connect(node,
driver_object_store_memory: Limit the amount of memory the driver can
use in the object store when creating objects.
job_id: The ID of job. If it's None, then we will generate one.
job_config (ray.job_config.JobConfig): The job configuration.
"""
# Do some basic checking to make sure we didn't call ray.init twice.
error_message = "Perhaps you called ray.init twice by accident?"
@@ -1265,7 +1272,9 @@ def connect(node,
int(redis_port),
node.redis_password,
)
if job_config is None:
job_config = ray.job_config.JobConfig()
serialized_job_config = job_config.serialize()
worker.core_worker = ray._raylet.CoreWorker(
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
node.plasma_store_socket_name,
@@ -1280,6 +1289,7 @@ def connect(node,
driver_name,
log_stdout_file_path,
log_stderr_file_path,
serialized_job_config,
)
# Create an object for interfacing with the global state.
+3
View File
@@ -337,3 +337,6 @@ RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1)
/// The maximum number of resource shapes included in the resource
/// load reported by each raylet.
RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100)
/// Whether to enable multi-tenancy features (job isolation and per-job config).
/// Disabled by default.
RAY_CONFIG(bool, enable_multi_tenancy, false)
+2 -1
View File
@@ -311,7 +311,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(),
(options_.worker_type == ray::WorkerType::WORKER),
worker_context_.GetCurrentJobID(), options_.language, options_.node_ip_address,
&local_raylet_id, &assigned_port, &internal_config));
&local_raylet_id, &assigned_port, &internal_config,
options_.serialized_job_config));
connected_ = true;
RAY_CHECK(assigned_port != -1)
+2
View File
@@ -119,6 +119,8 @@ struct CoreWorkerOptions {
int num_workers;
/// The function to destroy asyncio event and loops.
std::function<void()> terminate_asyncio_thread;
/// Serialized representation of JobConfig.
std::string serialized_job_config;
};
/// Lifecycle management of one or more `CoreWorker` instances in a process.
@@ -93,7 +93,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort,
jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId,
jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir,
jobject rayletConfigParameters) {
jobject rayletConfigParameters, jbyteArray jobConfig) {
auto raylet_config = JavaMapToNativeMap<std::string, std::string>(
env, rayletConfigParameters,
[](JNIEnv *env, jobject java_key) {
@@ -203,6 +203,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
}
};
std::string serialized_job_config =
(jobConfig == nullptr ? "" : JavaByteArrayToNativeString(env, jobConfig));
ray::CoreWorkerOptions options = {
static_cast<ray::WorkerType>(workerMode), // worker_type
ray::Language::JAVA, // langauge
@@ -228,6 +230,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
true, // ref_counting_enabled
false, // is_local_mode
static_cast<int>(numWorkersPerProcess), // num_workers
nullptr, // terminate_asyncio_thread
serialized_job_config, // serialized_job_config
};
ray::CoreWorkerProcess::Initialize(options);
@@ -29,7 +29,7 @@ extern "C" {
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
jint, jstring, jobject);
jint, jstring, jobject, jbyteArray);
/*
* Class: io_ray_runtime_RayNativeRuntime
+9 -1
View File
@@ -241,6 +241,14 @@ class JavaByteArrayBuffer : public ray::Buffer {
jbyte *native_bytes_;
};
/// Convert a Java byte array to a C++ string.
///
/// \param env The JNI environment used to read the array.
/// \param bytes The Java byte array to copy. May be null.
/// \return A string holding a copy of the array's bytes, or an empty string if
/// `bytes` is null or has no elements.
inline std::string JavaByteArrayToNativeString(JNIEnv *env, const jbyteArray &bytes) {
  if (bytes == nullptr) {
    return std::string();
  }
  const auto size = env->GetArrayLength(bytes);
  std::string str(size, 0);
  if (size > 0) {
    // Only take the buffer address when there is something to copy: calling
    // front() on an empty string is undefined behavior.
    env->GetByteArrayRegion(bytes, 0, size, reinterpret_cast<jbyte *>(&str[0]));
  }
  return str;
}
/// Convert a Java byte array to a C++ UniqueID.
template <typename ID>
inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) {
@@ -516,4 +524,4 @@ inline std::string GetActorFullName(bool global, std::string name) {
return global ? name
: ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() +
"-" + name;
}
}
+2 -2
View File
@@ -35,14 +35,14 @@ namespace gcs {
inline std::shared_ptr<ray::rpc::JobTableData> CreateJobTableData(
const ray::JobID &job_id, bool is_dead, int64_t timestamp,
const std::string &driver_ip_address, int64_t driver_pid,
const ray::rpc::JobConfigs &job_configs = {}) {
const ray::rpc::JobConfig &job_config = {}) {
auto job_info_ptr = std::make_shared<ray::rpc::JobTableData>();
job_info_ptr->set_job_id(job_id.Binary());
job_info_ptr->set_is_dead(is_dead);
job_info_ptr->set_timestamp(timestamp);
job_info_ptr->set_driver_ip_address(driver_ip_address);
job_info_ptr->set_driver_pid(driver_pid);
*job_info_ptr->mutable_configs() = job_configs;
*job_info_ptr->mutable_config() = job_config;
return job_info_ptr;
}
+6 -12
View File
@@ -309,19 +309,13 @@ message TaskLeaseData {
uint64 timeout = 4;
}
message JobConfigs {
// The initial Python workers to start per node. If a negative value is specified, it
// fallbacks to `num_cpus`.
int32 num_initial_python_workers = 1;
// The initial Java workers to start per node. If a negative value is specified, it
// fallbacks to `num_cpus`.
int32 num_initial_java_workers = 2;
message JobConfig {
// Environment variables to be set on worker processes.
map<string, string> worker_env = 3;
map<string, string> worker_env = 1;
// The number of java workers per worker process.
uint32 num_java_workers_per_process = 4;
uint32 num_java_workers_per_process = 2;
// The jvm options for java workers of the job.
repeated string jvm_options = 5;
repeated string jvm_options = 3;
}
message JobTableData {
@@ -335,8 +329,8 @@ message JobTableData {
string driver_ip_address = 4;
// Process ID of the driver running this job.
int64 driver_pid = 5;
// The configs of this job.
JobConfigs configs = 6;
// The config of this job.
JobConfig config = 6;
}
// This table stores the actor checkpoint data. An actor checkpoint
+3 -1
View File
@@ -157,6 +157,8 @@ table RegisterClientRequest {
ip_address: string;
// Port that this worker is listening on.
port: int;
// The config bytes of this job serialized with protobuf.
serialized_job_config: string;
}
table RegisterClientReply {
@@ -311,6 +313,6 @@ table SetResourceRequest {
}
table SubscribePlasmaReady {
// ObjectID to wait for
// ObjectID to wait for
object_id: string;
}
+6
View File
@@ -36,6 +36,8 @@ DEFINE_int32(min_worker_port, 0,
DEFINE_int32(max_worker_port, 0,
"The highest port that workers' gRPC servers will bind on.");
DEFINE_int32(num_initial_workers, 0, "Number of initial workers.");
DEFINE_int32(num_initial_python_workers_for_first_job, 0,
"Number of initial Python workers for the first job.");
DEFINE_int32(maximum_startup_concurrency, 1, "Maximum startup concurrency");
DEFINE_string(static_resource_list, "", "The static resource list of this node.");
DEFINE_string(config_list, "", "The raylet config list of this node.");
@@ -71,6 +73,8 @@ int main(int argc, char *argv[]) {
const int min_worker_port = static_cast<int>(FLAGS_min_worker_port);
const int max_worker_port = static_cast<int>(FLAGS_max_worker_port);
const int num_initial_workers = static_cast<int>(FLAGS_num_initial_workers);
const int num_initial_python_workers_for_first_job =
static_cast<int>(FLAGS_num_initial_python_workers_for_first_job);
const int maximum_startup_concurrency =
static_cast<int>(FLAGS_maximum_startup_concurrency);
const std::string static_resource_list = FLAGS_static_resource_list;
@@ -160,6 +164,8 @@ int main(int argc, char *argv[]) {
node_manager_config.node_manager_address = node_ip_address;
node_manager_config.node_manager_port = node_manager_port;
node_manager_config.num_initial_workers = num_initial_workers;
node_manager_config.num_initial_python_workers_for_first_job =
num_initial_python_workers_for_first_job;
node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency;
node_manager_config.min_worker_port = min_worker_port;
node_manager_config.max_worker_port = max_worker_port;
+49 -39
View File
@@ -162,9 +162,11 @@ NodeManager::NodeManager(boost::asio::io_service &io_service,
initial_config_(config),
local_available_resources_(config.resource_config),
worker_pool_(
io_service, config.num_initial_workers, config.maximum_startup_concurrency,
config.min_worker_port, config.max_worker_port, gcs_client_,
config.worker_commands, config.raylet_config,
io_service, config.num_initial_workers,
config.num_initial_python_workers_for_first_job,
config.maximum_startup_concurrency, config.min_worker_port,
config.max_worker_port, gcs_client_, config.worker_commands,
config.raylet_config,
/*starting_worker_timeout_callback=*/
[this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }),
scheduling_policy_(local_queues_),
@@ -347,12 +349,20 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_
RAY_LOG(DEBUG) << "HandleJobStarted " << job_id;
RAY_CHECK(!job_data.is_dead());
// TODO(kfstorm): Spawn job initial workers in a later PR.
worker_pool_.HandleJobStarted(job_id, job_data.config());
if (RayConfig::instance().enable_multi_tenancy()) {
// Tasks of this job may already arrived but failed to pop a worker because the job
// config is not local yet. So we trigger dispatching again here to try to
// reschedule these tasks.
DispatchTasks(local_queues_.GetReadyTasksByClass());
}
}
void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job_data) {
RAY_LOG(DEBUG) << "HandleJobFinished " << job_id;
RAY_CHECK(job_data.is_dead());
worker_pool_.HandleJobFinished(job_id);
auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id);
// Kill all the workers. The actual cleanup for these workers is done
// later when we receive the DisconnectClient message from them.
@@ -1218,13 +1228,31 @@ void NodeManager::ProcessRegisterClientRequestMessage(
auto worker = std::dynamic_pointer_cast<WorkerInterface>(std::make_shared<Worker>(
worker_id, language, worker_ip_address, client, client_call_manager_));
int assigned_port;
auto send_reply_callback = [this, client](int assigned_port) {
flatbuffers::FlatBufferBuilder fbb;
std::vector<std::string> internal_config_keys;
std::vector<std::string> internal_config_values;
for (auto kv : initial_config_.raylet_config) {
internal_config_keys.push_back(kv.first);
internal_config_values.push_back(kv.second);
}
auto reply = ray::protocol::CreateRegisterClientReply(
fbb, to_flatbuf(fbb, self_node_id_), assigned_port,
string_vec_to_flatbuf(fbb, internal_config_keys),
string_vec_to_flatbuf(fbb, internal_config_values));
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply), fbb.GetSize(),
fbb.GetBufferPointer(), [this, client](const ray::Status &status) {
if (!status.ok()) {
ProcessDisconnectClientMessage(client);
}
});
};
if (message->is_worker()) {
// Register the new worker.
if (!worker_pool_.RegisterWorker(worker, pid, &assigned_port).ok()) {
// Return -1 to signal to the worker that registration failed.
assigned_port = -1;
}
RAY_UNUSED(worker_pool_.RegisterWorker(worker, pid, send_reply_callback));
} else {
// Register the new driver.
RAY_CHECK(pid >= 0);
@@ -1233,38 +1261,18 @@ void NodeManager::ProcessRegisterClientRequestMessage(
// Compute a dummy driver task id from a given driver.
const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id);
worker->AssignTaskId(driver_task_id);
worker->AssignJobId(job_id);
Status status = worker_pool_.RegisterDriver(worker, &assigned_port);
rpc::JobConfig job_config;
job_config.ParseFromString(message->serialized_job_config()->str());
Status status =
worker_pool_.RegisterDriver(worker, job_id, job_config, send_reply_callback);
if (status.ok()) {
local_queues_.AddDriverTaskId(driver_task_id);
auto job_data_ptr = gcs::CreateJobTableData(
job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid);
auto job_data_ptr =
gcs::CreateJobTableData(job_id, /*is_dead*/ false, std::time(nullptr),
worker_ip_address, pid, job_config);
RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr));
} else {
// Return -1 to signal to the worker that registration failed.
assigned_port = -1;
}
}
flatbuffers::FlatBufferBuilder fbb;
std::vector<std::string> internal_config_keys;
std::vector<std::string> internal_config_values;
for (auto kv : initial_config_.raylet_config) {
internal_config_keys.push_back(kv.first);
internal_config_values.push_back(kv.second);
}
auto reply = ray::protocol::CreateRegisterClientReply(
fbb, to_flatbuf(fbb, self_node_id_), assigned_port,
string_vec_to_flatbuf(fbb, internal_config_keys),
string_vec_to_flatbuf(fbb, internal_config_values));
fbb.Finish(reply);
client->WriteMessageAsync(
static_cast<int64_t>(protocol::MessageType::RegisterClientReply), fbb.GetSize(),
fbb.GetBufferPointer(), [this, client](const ray::Status &status) {
if (!status.ok()) {
ProcessDisconnectClientMessage(client);
}
});
}
void NodeManager::ProcessAnnounceWorkerPortMessage(
@@ -2690,9 +2698,11 @@ bool NodeManager::FinishAssignedTask(WorkerInterface &worker) {
task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId());
task_dependency_manager_.TaskCanceled(task_id);
// Unset the worker's assigned job Id if this is not an actor.
if (!spec.IsActorCreationTask() && !spec.IsActorTask()) {
worker.AssignJobId(JobID::Nil());
if (!RayConfig::instance().enable_multi_tenancy()) {
// Unset the worker's assigned job Id if this is not an actor.
if (!spec.IsActorCreationTask() && !spec.IsActorTask()) {
worker.AssignJobId(JobID::Nil());
}
}
if (!spec.IsActorCreationTask()) {
// Unset the worker's assigned task. We keep the assigned task ID for
+3
View File
@@ -67,6 +67,8 @@ struct NodeManagerConfig {
int max_worker_port;
/// The initial number of workers to create.
int num_initial_workers;
/// Number of initial Python workers for the first job.
int num_initial_python_workers_for_first_job;
/// The maximum number of workers that can be started concurrently by a
/// worker pool.
int maximum_startup_concurrency;
@@ -731,6 +733,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
/// The resources (and specific resource IDs) that are currently available.
ResourceIdSet local_available_resources_;
std::unordered_map<ClientID, SchedulingResources> cluster_resource_map_;
/// A pool of workers.
WorkerPool worker_pool_;
/// A set of queues to maintain tasks.
@@ -125,7 +125,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers(
worker->SetAllocatedInstances(allocated_instances);
}
worker->AssignTaskId(spec.TaskId());
worker->AssignJobId(spec.JobId());
if (!RayConfig::instance().enable_multi_tenancy()) {
worker->AssignJobId(spec.JobId());
}
worker->SetAssignedTask(task);
Dispatch(worker, leased_workers, spec, reply, callback);
}
+14 -1
View File
@@ -106,7 +106,20 @@ const std::unordered_set<TaskID> &Worker::GetBlockedTaskIds() const {
return blocked_task_ids_;
}
void Worker::AssignJobId(const JobID &job_id) { assigned_job_id_ = job_id; }
/// Bind this worker to the given job.
///
/// With multi-tenancy disabled the stored job ID is simply overwritten. With
/// multi-tenancy enabled a worker may be bound to a job only once; assigning the
/// same job again is a no-op, while assigning a different job fails a RAY_CHECK.
///
/// \param job_id The job to assign to this worker.
void Worker::AssignJobId(const JobID &job_id) {
  if (!RayConfig::instance().enable_multi_tenancy()) {
    // Multi-tenancy disabled: unconditional overwrite.
    assigned_job_id_ = job_id;
    return;
  }

  if (assigned_job_id_.IsNil()) {
    // First assignment for this worker.
    assigned_job_id_ = job_id;
    RAY_LOG(INFO) << "Assigned worker " << worker_id_ << " to job " << job_id;
    return;
  }

  // Already bound to a job: only that same job ID is accepted.
  RAY_CHECK(assigned_job_id_ == job_id)
      << "The worker " << worker_id_ << " is already assigned to job "
      << assigned_job_id_ << ". It cannot be reassigned to job " << job_id;
}
const JobID &Worker::GetAssignedJobId() const { return assigned_job_id_; }
+209 -58
View File
@@ -54,9 +54,8 @@ namespace ray {
namespace raylet {
/// A constructor that initializes a worker pool with num_workers workers for
/// each language.
WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
int num_initial_python_workers_for_first_job,
int maximum_startup_concurrency, int min_worker_port,
int max_worker_port, std::shared_ptr<gcs::GcsClient> gcs_client,
const WorkerCommandMap &worker_commands,
@@ -66,7 +65,12 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
maximum_startup_concurrency_(maximum_startup_concurrency),
gcs_client_(std::move(gcs_client)),
raylet_config_(raylet_config),
starting_worker_timeout_callback_(starting_worker_timeout_callback) {
starting_worker_timeout_callback_(starting_worker_timeout_callback),
first_job_registered_python_worker_count_(0),
first_job_driver_wait_num_python_workers_(std::min(
num_initial_python_workers_for_first_job, maximum_startup_concurrency)),
num_initial_python_workers_for_first_job_(
num_initial_python_workers_for_first_job) {
RAY_CHECK(maximum_startup_concurrency > 0);
#ifndef _WIN32
// Ignore SIGCHLD signals. If we don't do this, then worker processes will
@@ -76,25 +80,29 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
for (const auto &entry : worker_commands) {
// Initialize the pool state for this language.
auto &state = states_by_lang_[entry.first];
switch (entry.first) {
case Language::PYTHON:
state.num_workers_per_process =
RayConfig::instance().num_workers_per_process_python();
break;
case Language::JAVA:
state.num_workers_per_process =
RayConfig::instance().num_workers_per_process_java();
break;
default:
RAY_LOG(FATAL) << "The number of workers per process for "
<< Language_Name(entry.first) << " worker is not set.";
if (!RayConfig::instance().enable_multi_tenancy()) {
switch (entry.first) {
case Language::PYTHON:
state.num_workers_per_process =
RayConfig::instance().num_workers_per_process_python();
break;
case Language::JAVA:
state.num_workers_per_process =
RayConfig::instance().num_workers_per_process_java();
break;
default:
RAY_LOG(FATAL) << "The number of workers per process for "
<< Language_Name(entry.first) << " worker is not set.";
}
RAY_CHECK(state.num_workers_per_process > 0)
<< "Number of workers per process of language " << Language_Name(entry.first)
<< " must be positive.";
state.multiple_for_warning =
std::max(state.num_workers_per_process,
std::max(num_workers, maximum_startup_concurrency));
} else {
state.multiple_for_warning = maximum_startup_concurrency;
}
RAY_CHECK(state.num_workers_per_process > 0)
<< "Number of workers per process of language " << Language_Name(entry.first)
<< " must be positive.";
state.multiple_for_warning =
std::max(state.num_workers_per_process,
std::max(num_workers, maximum_startup_concurrency));
// Set worker command for this language.
state.worker_command = entry.second;
RAY_CHECK(!state.worker_command.empty()) << "Worker command must not be empty.";
@@ -111,16 +119,19 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers,
free_ports_->push(port);
}
}
Start(num_workers);
if (!RayConfig::instance().enable_multi_tenancy()) {
Start(num_workers);
}
}
void WorkerPool::Start(int num_workers) {
RAY_CHECK(!RayConfig::instance().enable_multi_tenancy());
for (auto &entry : states_by_lang_) {
auto &state = entry.second;
int num_worker_processes = static_cast<int>(
std::ceil(static_cast<double>(num_workers) / state.num_workers_per_process));
for (int i = 0; i < num_worker_processes; i++) {
StartWorkerProcess(entry.first);
StartWorkerProcess(entry.first, JobID::Nil());
}
}
}
@@ -154,8 +165,20 @@ uint32_t WorkerPool::Size(const Language &language) const {
}
}
Process WorkerPool::StartWorkerProcess(const Language &language,
const std::vector<std::string> &dynamic_options) {
Process WorkerPool::StartWorkerProcess(const Language &language, const JobID &job_id,
std::vector<std::string> dynamic_options) {
rpc::JobConfig *job_config = nullptr;
if (RayConfig::instance().enable_multi_tenancy()) {
RAY_CHECK(!job_id.IsNil());
auto it = unfinished_jobs_.find(job_id);
if (it == unfinished_jobs_.end()) {
RAY_LOG(DEBUG) << "Job config of job " << job_id << " are not local yet.";
// Will reschedule ready tasks in `NodeManager::HandleJobStarted`.
return Process();
}
job_config = &it->second;
}
auto &state = GetStateForLanguage(language);
// If we are already starting up too many workers, then return without starting
// more.
@@ -175,11 +198,21 @@ Process WorkerPool::StartWorkerProcess(const Language &language,
<< state.idle_actor.size() << " actor workers, and " << state.idle.size()
<< " non-actor workers";
int workers_to_start;
int workers_to_start = 1;
if (dynamic_options.empty()) {
workers_to_start = state.num_workers_per_process;
} else {
workers_to_start = 1;
if (!RayConfig::instance().enable_multi_tenancy()) {
workers_to_start = state.num_workers_per_process;
} else if (language == Language::JAVA) {
workers_to_start = job_config->num_java_workers_per_process();
}
}
if (RayConfig::instance().enable_multi_tenancy() &&
!job_config->jvm_options().empty()) {
// Note that we push the item to the front of the vector to make
// sure this is the freshest option than others.
dynamic_options.insert(dynamic_options.begin(), job_config->jvm_options().begin(),
job_config->jvm_options().end());
}
// Extract pointers from the worker command to pass into execvp.
@@ -216,12 +249,17 @@ Process WorkerPool::StartWorkerProcess(const Language &language,
arg.append(entry.second);
worker_command_args.push_back(arg);
}
// The value of `num_workers_per_process_java` may change depends on whether
// dynamic options is empty, so we can't use the value in `RayConfig`. We always
// overwrite the value here.
worker_command_args.push_back(
"-Dray.raylet.config.num_workers_per_process_java=" +
std::to_string(workers_to_start));
if (!RayConfig::instance().enable_multi_tenancy()) {
// The value of `num_workers_per_process_java` may change depending on whether
// the dynamic options are empty, so we can't use the value in `RayConfig`. We
// always overwrite the value here.
worker_command_args.push_back(
"-Dray.raylet.config.num_workers_per_process_java=" +
std::to_string(workers_to_start));
} else {
worker_command_args.push_back("-Dray.job.num-java-workers-per-process=" +
std::to_string(workers_to_start));
}
break;
default:
RAY_LOG(FATAL)
@@ -241,7 +279,14 @@ Process WorkerPool::StartWorkerProcess(const Language &language,
<< " placeholder is not found in worker command.";
}
// TODO(kfstorm): Set up environment variables in a later PR.
Process proc = StartProcess(worker_command_args);
if (RayConfig::instance().enable_multi_tenancy()) {
// If the pid is reused between processes, the old process must have exited.
// So it's safe to bind the pid with another job ID.
RAY_LOG(DEBUG) << "Worker process " << proc.GetId() << " is bound to job " << job_id;
state.worker_pids_to_assigned_jobs[proc.GetId()] = job_id;
}
RAY_LOG(DEBUG) << "Started worker process of " << workers_to_start
<< " worker(s) with pid " << proc.GetId();
MonitorStartingWorkerProcess(proc, language);
@@ -327,34 +372,120 @@ void WorkerPool::MarkPortAsFree(int port) {
}
}
/// Remember the config of a started job, replacing any config previously stored
/// for the same job ID.
///
/// \param job_id ID of the started job.
/// \param job_config The config of the started job.
void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) {
  unfinished_jobs_.insert_or_assign(job_id, job_config);
}
/// Forget the config of a finished job. Does nothing if the job is not tracked.
///
/// \param job_id ID of the finished job.
void WorkerPool::HandleJobFinished(const JobID &job_id) {
  const auto job_it = unfinished_jobs_.find(job_id);
  if (job_it != unfinished_jobs_.end()) {
    unfinished_jobs_.erase(job_it);
  }
}
Status WorkerPool::RegisterWorker(const std::shared_ptr<WorkerInterface> &worker,
pid_t pid, int *port) {
pid_t pid,
std::function<void(int)> send_reply_callback) {
RAY_CHECK(worker);
// The port that this worker's gRPC server should listen on. 0 if the worker
// should bind on a random port.
int port;
Status status;
auto &state = GetStateForLanguage(worker->GetLanguage());
auto it = state.starting_worker_processes.find(Process::FromPid(pid));
if (it == state.starting_worker_processes.end()) {
RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid;
return Status::Invalid("Unknown worker");
}
RAY_RETURN_NOT_OK(GetNextFreePort(port));
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << *port;
worker->SetAssignedPort(*port);
worker->SetProcess(it->first);
it->second--;
if (it->second == 0) {
state.starting_worker_processes.erase(it);
// Return -1 to signal to the worker that registration failed.
port = -1;
status = Status::Invalid("Unknown worker");
} else {
RAY_RETURN_NOT_OK(GetNextFreePort(&port));
RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port;
worker->SetAssignedPort(port);
worker->SetProcess(it->first);
it->second--;
if (it->second == 0) {
state.starting_worker_processes.erase(it);
}
RAY_CHECK(worker->GetProcess().GetId() == pid);
state.registered_workers.insert(worker);
if (RayConfig::instance().enable_multi_tenancy()) {
auto dedicated_workers_it = state.worker_pids_to_assigned_jobs.find(pid);
RAY_CHECK(dedicated_workers_it != state.worker_pids_to_assigned_jobs.end());
auto job_id = dedicated_workers_it->second;
worker->AssignJobId(job_id);
// We don't call state.worker_pids_to_assigned_jobs.erase(pid) here
// because we allow multiple workers per worker process.
// This is a workaround to finish driver registration after all initial workers are
// registered to Raylet if and only if Raylet is started by a Python driver and the
// job config is not set in `ray.init(...)`.
if (first_job_ == job_id && worker->GetLanguage() == Language::PYTHON) {
if (++first_job_registered_python_worker_count_ ==
first_job_driver_wait_num_python_workers_) {
if (first_job_send_register_client_reply_to_driver_) {
first_job_send_register_client_reply_to_driver_();
first_job_send_register_client_reply_to_driver_ = nullptr;
}
}
}
}
status = Status::OK();
}
state.registered_workers.emplace(std::move(worker));
return Status::OK();
// Send the reply immediately for worker registrations.
if (send_reply_callback) {
send_reply_callback(port);
}
return status;
}
/// Register a driver for `job_id`: assign it a port, record its job config in
/// `unfinished_jobs_`, and deliver the assigned port through `send_reply_callback`
/// (when one is provided).
///
/// When multi-tenancy is enabled, this is the first job seen by the pool, and
/// `num_initial_python_workers_for_first_job_` is positive, the initial Python workers
/// are started here and the reply is stored in
/// `first_job_send_register_client_reply_to_driver_` instead of being sent immediately.
///
/// \param driver The driver to be registered.
/// \param job_id The job ID of the driver.
/// \param job_config The config of the job.
/// \param send_reply_callback Invoked with the assigned port; may be empty.
/// \return OK on success, or the error from `GetNextFreePort`.
Status WorkerPool::RegisterDriver(const std::shared_ptr<WorkerInterface> &driver,
int *port) {
const JobID &job_id, const rpc::JobConfig &job_config,
std::function<void(int)> send_reply_callback) {
int port;
RAY_CHECK(!driver->GetAssignedTaskId().IsNil());
RAY_RETURN_NOT_OK(GetNextFreePort(port));
driver->SetAssignedPort(*port);
RAY_RETURN_NOT_OK(GetNextFreePort(&port));
driver->SetAssignedPort(port);
auto &state = GetStateForLanguage(driver->GetLanguage());
// `driver` is a const reference, so std::move could never move from it; insert a copy
// explicitly because `driver` is still used below.
state.registered_drivers.insert(driver);
driver->AssignJobId(job_id);
unfinished_jobs_[job_id] = job_config;
if (send_reply_callback) {
// This is a workaround to start initial workers on this node if and only if Raylet is
// started by a Python driver and the job config is not set in `ray.init(...)`.
// Invoke the `send_reply_callback` later to only finish driver
// registration after all initial workers are registered to Raylet.
bool delay_callback = false;
// Multi-tenancy is enabled.
if (RayConfig::instance().enable_multi_tenancy()) {
// If this is the first job.
if (first_job_.IsNil()) {
first_job_ = job_id;
// If the number of Python workers we need to wait is positive.
if (num_initial_python_workers_for_first_job_ > 0) {
delay_callback = true;
// Start initial Python workers for the first job.
for (int i = 0; i < num_initial_python_workers_for_first_job_; i++) {
StartWorkerProcess(Language::PYTHON, job_id);
}
}
}
}
if (delay_callback) {
RAY_CHECK(!first_job_send_register_client_reply_to_driver_);
first_job_send_register_client_reply_to_driver_ = [send_reply_callback, port]() {
send_reply_callback(port);
};
} else {
send_reply_callback(port);
}
}
return Status::OK();
}
@@ -424,8 +555,8 @@ std::shared_ptr<WorkerInterface> WorkerPool::PopWorker(
} else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) {
// We are not pending a registration from a worker for this task,
// so start a new worker process for this task.
proc =
StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions());
proc = StartWorkerProcess(task_spec.GetLanguage(), task_spec.JobId(),
task_spec.DynamicWorkerOptions());
if (proc.IsValid()) {
state.dedicated_workers_to_tasks[proc] = task_spec.TaskId();
state.tasks_to_dedicated_workers[task_spec.TaskId()] = proc;
@@ -433,13 +564,30 @@ std::shared_ptr<WorkerInterface> WorkerPool::PopWorker(
}
} else if (!task_spec.IsActorTask()) {
// Code path of normal task or actor creation task without dynamic worker options.
if (!state.idle.empty()) {
worker = std::move(*state.idle.begin());
state.idle.erase(state.idle.begin());
if (!RayConfig::instance().enable_multi_tenancy()) {
if (!state.idle.empty()) {
worker = std::move(*state.idle.begin());
state.idle.erase(state.idle.begin());
} else {
// There are no more non-actor workers available to execute this task.
// Start a new worker process.
proc = StartWorkerProcess(task_spec.GetLanguage(), JobID::Nil());
}
} else {
// There are no more non-actor workers available to execute this task.
// Start a new worker process.
proc = StartWorkerProcess(task_spec.GetLanguage());
// Find an available worker which is already assigned to this job.
for (auto it = state.idle.begin(); it != state.idle.end(); it++) {
if ((*it)->GetAssignedJobId() != task_spec.JobId()) {
continue;
}
worker = std::move(*it);
state.idle.erase(it);
break;
}
if (worker == nullptr) {
// There are no more non-actor workers available to execute this task.
// Start a new worker process.
proc = StartWorkerProcess(task_spec.GetLanguage(), task_spec.JobId());
}
}
} else {
// Code path of actor task.
@@ -455,6 +603,9 @@ std::shared_ptr<WorkerInterface> WorkerPool::PopWorker(
WarnAboutSize();
}
if (RayConfig::instance().enable_multi_tenancy() && worker) {
RAY_CHECK(worker->GetAssignedJobId() == task_spec.JobId());
}
return worker;
}
+50 -7
View File
@@ -73,6 +73,8 @@ class WorkerPool : public WorkerPoolInterface {
/// the pool.
///
/// \param num_workers The number of workers to start, per language.
/// \param num_initial_python_workers_for_first_job The number of initial Python
/// workers for the first job.
/// \param maximum_startup_concurrency The maximum number of worker processes
/// that can be started in parallel (typically this should be set to the number of CPU
/// resources on the machine).
@@ -86,6 +88,7 @@ class WorkerPool : public WorkerPoolInterface {
/// \param starting_worker_timeout_callback The callback that will be triggered once
/// it times out to start a worker.
WorkerPool(boost::asio::io_service &io_service, int num_workers,
int num_initial_python_workers_for_first_job,
int maximum_startup_concurrency, int min_worker_port, int max_worker_port,
std::shared_ptr<gcs::GcsClient> gcs_client,
const WorkerCommandMap &worker_commands,
@@ -95,24 +98,42 @@ class WorkerPool : public WorkerPoolInterface {
/// Destructor responsible for freeing a set of workers owned by this class.
virtual ~WorkerPool();
/// Handles the event that a job is started.
///
/// \param job_id ID of the started job.
/// \param job_config The config of the started job.
/// \return Void
void HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config);
/// Handles the event that a job is finished.
///
/// \param job_id ID of the finished job.
/// \return Void.
void HandleJobFinished(const JobID &job_id);
/// Register a new worker. The Worker should be added by the caller to the
/// pool after it becomes idle (e.g., requests a work assignment).
///
/// \param[in] worker The worker to be registered.
/// \param[in] pid The PID of the worker.
/// \param[out] port The port that this worker's gRPC server should listen on.
/// \param[in] send_reply_callback The callback to invoke after registration is
/// finished/failed.
/// Returns 0 if the worker should bind on a random port.
/// \return If the registration is successful.
Status RegisterWorker(const std::shared_ptr<WorkerInterface> &worker, pid_t pid,
int *port);
std::function<void(int)> send_reply_callback);
/// Register a new driver.
///
/// \param[in] worker The driver to be registered.
/// \param[out] port The port that this driver's gRPC server should listen on.
/// Returns 0 if the driver should bind on a random port.
/// \param[in] job_id The job ID of the driver.
/// \param[in] job_config The config of the job.
/// \param[in] send_reply_callback The callback to invoke after registration is
/// finished/failed.
/// \return If the registration is successful.
Status RegisterDriver(const std::shared_ptr<WorkerInterface> &worker, int *port);
Status RegisterDriver(const std::shared_ptr<WorkerInterface> &worker,
const JobID &job_id, const rpc::JobConfig &job_config,
std::function<void(int)> send_reply_callback);
/// Get the client connection's registered worker.
///
@@ -208,11 +229,12 @@ class WorkerPool : public WorkerPoolInterface {
/// any workers.
///
/// \param language Which language this worker process should be.
/// \param job_id The ID of the job to which the started worker process belongs.
/// \param dynamic_options The dynamic options that we should add for worker command.
/// \return The id of the process that we started if it's positive,
/// otherwise it means we didn't start a process.
Process StartWorkerProcess(const Language &language,
const std::vector<std::string> &dynamic_options = {});
Process StartWorkerProcess(const Language &language, const JobID &job_id,
std::vector<std::string> dynamic_options = {});
/// The implementation of how to start a new worker process with command arguments.
/// The lifetime of the process is tied to that of the returned object,
@@ -251,6 +273,8 @@ class WorkerPool : public WorkerPoolInterface {
std::unordered_map<Process, TaskID> dedicated_workers_to_tasks;
/// A map for speeding up looking up the pending worker for the given task.
std::unordered_map<TaskID, Process> tasks_to_dedicated_workers;
/// A map for looking up the owner JobId by the pid of worker.
std::unordered_map<pid_t, JobID> worker_pids_to_assigned_jobs;
/// We'll push a warning to the user every time a multiple of this many
/// worker processes has been started.
int multiple_for_warning;
@@ -307,6 +331,25 @@ class WorkerPool : public WorkerPoolInterface {
/// The callback that will be triggered once it times out to start a worker.
std::function<void()> starting_worker_timeout_callback_;
FRIEND_TEST(WorkerPoolTest, InitialWorkerProcessCount);
/// The Job ID of the firstly received job.
JobID first_job_;
/// The callback to send RegisterClientReply to the driver of the first job.
std::function<void()> first_job_send_register_client_reply_to_driver_;
/// The number of registered Python workers of the first job.
int first_job_registered_python_worker_count_;
/// The number of initial Python workers to wait for the first job before the driver
/// receives RegisterClientReply.
int first_job_driver_wait_num_python_workers_;
/// The number of initial Python workers for the first job.
int num_initial_python_workers_for_first_job_;
/// This map tracks the latest infos of unfinished jobs.
absl::flat_hash_map<JobID, rpc::JobConfig> unfinished_jobs_;
};
} // namespace raylet
+140 -44
View File
@@ -26,21 +26,15 @@ namespace raylet {
int NUM_WORKERS_PER_PROCESS_JAVA = 3;
int MAXIMUM_STARTUP_CONCURRENCY = 5;
JobID JOB_ID = JobID::FromInt(1);
std::vector<Language> LANGUAGES = {Language::PYTHON, Language::JAVA};
class WorkerPoolMock : public WorkerPool {
public:
WorkerPoolMock(boost::asio::io_service &io_service)
: WorkerPoolMock(
io_service,
{{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA,
{"dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"}}}) {}
explicit WorkerPoolMock(boost::asio::io_service &io_service,
const WorkerCommandMap &worker_commands)
: WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr,
: WorkerPool(io_service, 0, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr,
worker_commands, {}, []() {}),
last_worker_process_() {
states_by_lang_[ray::Language::JAVA].num_workers_per_process =
@@ -94,14 +88,22 @@ class WorkerPoolMock : public WorkerPool {
std::unordered_map<Process, std::vector<std::string>> worker_commands_by_proc_;
};
class WorkerPoolTest : public ::testing::Test {
class WorkerPoolTest : public ::testing::TestWithParam<bool> {
public:
WorkerPoolTest() : error_message_type_(1), client_call_manager_(io_service_) {
worker_pool_ = std::unique_ptr<WorkerPoolMock>(new WorkerPoolMock(io_service_));
bool enable_multi_tenancy = GetParam();
RayConfig::instance().initialize(
{{"enable_multi_tenancy", std::to_string(enable_multi_tenancy)},
{"num_workers_per_process_java", std::to_string(NUM_WORKERS_PER_PROCESS_JAVA)}});
SetWorkerCommands(
{{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA,
{"dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"}}});
}
std::shared_ptr<WorkerInterface> CreateWorker(
Process proc, const Language &language = Language::PYTHON) {
Process proc, const Language &language = Language::PYTHON,
const JobID &job_id = JOB_ID) {
std::function<void(ClientConnection &)> client_handler =
[this](ClientConnection &client) { HandleNewClient(client); };
std::function<void(std::shared_ptr<ClientConnection>, int64_t,
@@ -119,15 +121,28 @@ class WorkerPoolTest : public ::testing::Test {
WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_);
std::shared_ptr<WorkerInterface> worker =
std::dynamic_pointer_cast<WorkerInterface>(worker_);
worker->AssignJobId(job_id);
if (!proc.IsNull()) {
worker->SetProcess(proc);
}
return worker;
}
/// Create a dummy driver for `job_id`, give it the driver task ID of that job, and
/// register it with the worker pool under test (no reply callback).
///
/// \param language The language of the driver to create.
/// \param job_id The job the driver belongs to.
/// \param job_config The job config passed to `WorkerPool::RegisterDriver`.
/// \return The registered driver.
std::shared_ptr<WorkerInterface> RegisterDriver(
    const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID,
    const rpc::JobConfig &job_config = rpc::JobConfig()) {
  // Honor the requested language; it was previously ignored in favor of a hard-coded
  // Language::PYTHON.
  auto driver = CreateWorker(Process::CreateNewDummy(), language, job_id);
  driver->AssignTaskId(TaskID::ForDriverTask(job_id));
  RAY_CHECK_OK(worker_pool_->RegisterDriver(driver, job_id, job_config, nullptr));
  return driver;
}
void SetWorkerCommands(const WorkerCommandMap &worker_commands) {
worker_pool_ =
std::unique_ptr<WorkerPoolMock>(new WorkerPoolMock(io_service_, worker_commands));
rpc::JobConfig job_config;
job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA);
RegisterDriver(Language::PYTHON, JOB_ID, job_config);
}
void TestStartupWorkerProcessCount(Language language, int num_workers_per_process,
@@ -139,7 +154,7 @@ class WorkerPoolTest : public ::testing::Test {
static_cast<int>(desired_initial_worker_process_count));
Process last_started_worker_process;
for (int i = 0; i < desired_initial_worker_process_count; i++) {
worker_pool_->StartWorkerProcess(language);
worker_pool_->StartWorkerProcess(language, JOB_ID);
ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <=
expected_worker_process_count);
Process prev = worker_pool_->LastStartedWorkerProcess();
@@ -172,9 +187,10 @@ class WorkerPoolTest : public ::testing::Test {
static inline TaskSpecification ExampleTaskSpec(
const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON,
const ActorID actor_creation_id = ActorID::Nil(),
const JobID &job_id = JOB_ID, const ActorID actor_creation_id = ActorID::Nil(),
const std::vector<std::string> &dynamic_worker_options = {}) {
rpc::TaskSpec message;
message.set_job_id(job_id.Binary());
message.set_language(language);
if (!actor_id.IsNil()) {
message.set_type(TaskType::ACTOR_TASK);
@@ -191,7 +207,17 @@ static inline TaskSpecification ExampleTaskSpec(
return TaskSpecification(std::move(message));
}
TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) {
/// Build the `-D<key>=<num>` JVM system property string for the number of Java workers
/// per process. The key depends on whether multi-tenancy is enabled.
///
/// \param num The number of Java workers per process.
/// \return The system property argument.
static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) {
  const char *property_name = RayConfig::instance().enable_multi_tenancy()
                                  ? "ray.job.num-java-workers-per-process"
                                  : "ray.raylet.config.num_workers_per_process_java";
  std::string property("-D");
  property += property_name;
  property += "=";
  property += std::to_string(num);
  return property;
}
TEST_P(WorkerPoolTest, CompareWorkerProcessObjects) {
typedef Process T;
T a(T::CreateNewDummy()), b(T::CreateNewDummy()), empty = T();
ASSERT_TRUE(empty.IsNull());
@@ -205,8 +231,8 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) {
ASSERT_TRUE(!std::equal_to<T>()(a, empty));
}
TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
Process proc = worker_pool_->StartWorkerProcess(Language::JAVA);
TEST_P(WorkerPoolTest, HandleWorkerRegistration) {
Process proc = worker_pool_->StartWorkerProcess(Language::JAVA, JOB_ID);
std::vector<std::shared_ptr<WorkerInterface>> workers;
for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) {
workers.push_back(CreateWorker(Process(), Language::JAVA));
@@ -217,8 +243,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 1);
// Check that we cannot lookup the worker before it's registered.
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr);
int port;
RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), &port));
RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), nullptr));
// Check that we can lookup the worker after it's registered.
ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker);
}
@@ -231,30 +256,34 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) {
}
}
TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) {
// Runs the shared startup-count check for Python with one worker per process.
// NOTE(review): the last argument is presumably the expected worker command; confirm
// against TestStartupWorkerProcessCount.
TEST_P(WorkerPoolTest, StartupPythonWorkerProcessCount) {
TestStartupWorkerProcessCount(Language::PYTHON, 1, {"dummy_py_worker_command"});
}
TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) {
TEST_P(WorkerPoolTest, StartupJavaWorkerProcessCount) {
TestStartupWorkerProcessCount(
Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA,
{"dummy_java_worker_command",
std::string("-Dray.raylet.config.num_workers_per_process_java=") +
std::to_string(NUM_WORKERS_PER_PROCESS_JAVA)});
GetNumJavaWorkersPerProcessSystemProperty(NUM_WORKERS_PER_PROCESS_JAVA)});
}
TEST_F(WorkerPoolTest, InitialWorkerProcessCount) {
worker_pool_->Start(1);
// Here we try to start only 1 worker for each worker language. But since each Java
// worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here,
// it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for
// each worker language.
ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size());
ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA);
ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size());
// Checks how many workers and worker processes are starting right after pool setup.
TEST_P(WorkerPoolTest, InitialWorkerProcessCount) {
  if (RayConfig::instance().enable_multi_tenancy()) {
    // With multi-tenancy nothing is expected to be starting at this point.
    ASSERT_EQ(worker_pool_->NumWorkersStarting(), 0);
    ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 0);
    return;
  }

  worker_pool_->Start(1);
  // We ask for just 1 worker per language. Because every Java worker process hosts
  // exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here, we expect 3 Java workers
  // plus 1 Python worker rather than 1 worker for each language.
  ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size());
  ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA);
  ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size());
}
TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
TEST_P(WorkerPoolTest, HandleWorkerPushPop) {
// Try to pop a worker from the empty pool and make sure we don't get one.
std::shared_ptr<WorkerInterface> popped_worker;
const auto task_spec = ExampleTaskSpec();
@@ -281,7 +310,7 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) {
ASSERT_EQ(popped_worker, nullptr);
}
TEST_F(WorkerPoolTest, PopActorWorker) {
TEST_P(WorkerPoolTest, PopActorWorker) {
// Create a worker.
auto worker = CreateWorker(Process::CreateNewDummy());
// Add the worker to the pool.
@@ -290,8 +319,7 @@ TEST_F(WorkerPoolTest, PopActorWorker) {
// Assign an actor ID to the worker.
const auto task_spec = ExampleTaskSpec();
auto actor = worker_pool_->PopWorker(task_spec);
const auto job_id = JobID::FromInt(1);
auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
auto actor_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1);
actor->AssignActorId(actor_id);
worker_pool_->PushWorker(actor);
@@ -304,7 +332,7 @@ TEST_F(WorkerPoolTest, PopActorWorker) {
ASSERT_EQ(actor->GetActorId(), actor_id);
}
TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) {
TEST_P(WorkerPoolTest, PopWorkersOfMultipleLanguages) {
// Create a Python Worker, and add it to the pool
auto py_worker = CreateWorker(Process::CreateNewDummy(), Language::PYTHON);
worker_pool_->PushWorker(py_worker);
@@ -322,26 +350,94 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) {
ASSERT_NE(worker_pool_->PopWorker(java_task_spec), nullptr);
}
TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
TEST_P(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
const std::vector<std::string> java_worker_command = {
"RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0", "dummy_java_worker_command",
"RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER", "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1"};
SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}},
{Language::JAVA, java_worker_command}});
const auto job_id = JobID::FromInt(1);
TaskSpecification task_spec = ExampleTaskSpec(
ActorID::Nil(), Language::JAVA,
ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), {"test_op_0", "test_op_1"});
worker_pool_->StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions());
ActorID::Nil(), Language::JAVA, JOB_ID,
ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1), {"test_op_0", "test_op_1"});
worker_pool_->StartWorkerProcess(Language::JAVA, JOB_ID,
task_spec.DynamicWorkerOptions());
const auto real_command =
worker_pool_->GetWorkerCommand(worker_pool_->LastStartedWorkerProcess());
ASSERT_EQ(real_command,
std::vector<std::string>(
{"test_op_0", "dummy_java_worker_command",
"-Dray.raylet.config.num_workers_per_process_java=1", "test_op_1"}));
std::vector<std::string>({"test_op_0", "dummy_java_worker_command",
GetNumJavaWorkersPerProcessSystemProperty(1),
"test_op_1"}));
}
// Verifies that, with multi-tenancy enabled, every worker popped for a task is
// assigned to that task's job, and that workers returned to the pool are handed out
// again in a later round instead of new ones appearing.
TEST_P(WorkerPoolTest, PopWorkerMultiTenancy) {
  // Nothing to check when multi-tenancy is turned off.
  if (!RayConfig::instance().enable_multi_tenancy()) {
    return;
  }

  auto first_job_id = JOB_ID;
  auto second_job_id = JobID::FromInt(2);
  ASSERT_NE(first_job_id, second_job_id);
  JobID job_ids[] = {first_job_id, second_job_id};

  // The driver of job 1 is already registered, so only job 2 needs one here.
  RegisterDriver(Language::PYTHON, second_job_id);

  // Give each job 2 Python workers in the pool.
  for (auto job_id : job_ids) {
    for (int n = 0; n < 2; n++) {
      worker_pool_->PushWorker(
          CreateWorker(Process::CreateNewDummy(), Language::PYTHON, job_id));
    }
  }

  std::unordered_set<WorkerID> seen_worker_ids;
  for (int round = 0; round < 2; round++) {
    // Round 0 pops for actor creation tasks, round 1 for actor tasks.
    const bool is_first_round = (round == 0);
    std::vector<std::shared_ptr<WorkerInterface>> popped_workers;

    // One actor-related pop per job.
    for (auto job_id : job_ids) {
      auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1);
      auto task_spec =
          is_first_round
              ? ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id, actor_id)
              : ExampleTaskSpec(actor_id, Language::PYTHON, job_id, ActorID::Nil());
      auto worker = worker_pool_->PopWorker(task_spec);
      ASSERT_TRUE(worker);
      ASSERT_EQ(worker->GetAssignedJobId(), job_id);
      if (is_first_round) {
        worker->AssignActorId(actor_id);
      }
      popped_workers.push_back(worker);
    }

    // One normal-task pop per job.
    for (auto job_id : job_ids) {
      auto task_spec = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id);
      auto worker = worker_pool_->PopWorker(task_spec);
      ASSERT_TRUE(worker);
      ASSERT_EQ(worker->GetAssignedJobId(), job_id);
      popped_workers.push_back(worker);
    }

    // Hand every worker back and check whether its ID was seen before.
    for (const auto &worker : popped_workers) {
      worker_pool_->PushWorker(worker);
      if (is_first_round) {
        // Every worker popped in the first round must be a distinct, new one.
        ASSERT_TRUE(seen_worker_ids.insert(worker->WorkerId()).second);
      } else {
        // Every worker popped in the second round must be one from the first round.
        ASSERT_TRUE(seen_worker_ids.find(worker->WorkerId()) != seen_worker_ids.end());
      }
    }
  }
}
// Instantiates every TEST_P(WorkerPoolTest, ...) case once per listed parameter
// value, i.e. once with `true` and once with `false`.
// NOTE(review): the parameter presumably selects whether multi-tenancy is enabled;
// the fixture code that consumes GetParam() is not visible here -- confirm.
INSTANTIATE_TEST_CASE_P(WorkerPoolMultiTenancyTest, WorkerPoolTest,
::testing::Values(true, false));
} // namespace raylet
} // namespace ray
+7 -3
View File
@@ -83,8 +83,12 @@ raylet::RayletClient::RayletClient(
const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker,
const JobID &job_id, const Language &language, const std::string &ip_address,
ClientID *raylet_id, int *port,
std::unordered_map<std::string, std::string> *internal_config)
: grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) {
std::unordered_map<std::string, std::string> *internal_config,
const std::string &job_config)
: grpc_client_(std::move(grpc_client)),
worker_id_(worker_id),
job_id_(job_id),
job_config_(job_config) {
// For C++14, we could use std::make_unique
conn_ = std::unique_ptr<raylet::RayletConnection>(
new raylet::RayletConnection(io_service, raylet_socket, -1, -1));
@@ -92,7 +96,7 @@ raylet::RayletClient::RayletClient(
flatbuffers::FlatBufferBuilder fbb;
auto message = protocol::CreateRegisterClientRequest(
fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id),
language, fbb.CreateString(ip_address));
language, fbb.CreateString(ip_address), /*port=*/0, fbb.CreateString(job_config_));
fbb.Finish(message);
// Register the process ID with the raylet.
// NOTE(swang): If raylet exits and we are registered as a worker, we will get killed.
+4 -1
View File
@@ -177,7 +177,8 @@ class RayletClient : public PinObjectsInterface,
const std::string &raylet_socket, const WorkerID &worker_id,
bool is_worker, const JobID &job_id, const Language &language,
const std::string &ip_address, ClientID *raylet_id, int *port,
std::unordered_map<std::string, std::string> *internal_config);
std::unordered_map<std::string, std::string> *internal_config,
const std::string &job_config);
/// Connect to the raylet via grpc only.
///
@@ -369,6 +370,8 @@ class RayletClient : public PinObjectsInterface,
std::shared_ptr<ray::rpc::NodeManagerWorkerClient> grpc_client_;
const WorkerID worker_id_;
const JobID job_id_;
const std::string job_config_;
/// A map from resource name to the resource IDs that are currently reserved
/// for this worker. Each pair consists of the resource ID and the fraction
/// of that resource allocated for this worker.