From 27cd323ce1ccfd102fd401c268c8183bfa3516d1 Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Tue, 4 Aug 2020 15:51:29 +0800 Subject: [PATCH] [Core] Multi-tenancy: Job isolation & implement per job config (except for env variables) (#9500) --- .../java/io/ray/runtime/RayNativeRuntime.java | 16 +- .../java/io/ray/runtime/config/RayConfig.java | 34 ++- .../src/main/resources/ray.default.conf | 9 + .../main/java/io/ray/test/JobConfigTest.java | 69 +++++ .../java/io/ray/test/MultiDriverTest.java | 130 +++++++++ .../java/io/ray/test/RayletConfigTest.java | 4 +- python/ray/_raylet.pyx | 4 +- python/ray/gcs_utils.py | 2 + python/ray/includes/libcoreworker.pxd | 1 + python/ray/job_config.py | 38 +++ python/ray/node.py | 4 +- python/ray/parameter.py | 5 + python/ray/services.py | 6 +- python/ray/tests/BUILD | 8 + python/ray/tests/test_multi_tenancy.py | 116 ++++++++ python/ray/worker.py | 16 +- src/ray/common/ray_config_def.h | 3 + src/ray/core_worker/core_worker.cc | 3 +- src/ray/core_worker/core_worker.h | 2 + .../java/io_ray_runtime_RayNativeRuntime.cc | 6 +- .../java/io_ray_runtime_RayNativeRuntime.h | 2 +- src/ray/core_worker/lib/java/jni_utils.h | 10 +- src/ray/gcs/pb_util.h | 4 +- src/ray/protobuf/gcs.proto | 18 +- src/ray/raylet/format/node_manager.fbs | 4 +- src/ray/raylet/main.cc | 6 + src/ray/raylet/node_manager.cc | 88 +++--- src/ray/raylet/node_manager.h | 3 + .../raylet/scheduling/cluster_task_manager.cc | 4 +- src/ray/raylet/worker.cc | 15 +- src/ray/raylet/worker_pool.cc | 267 ++++++++++++++---- src/ray/raylet/worker_pool.h | 57 +++- src/ray/raylet/worker_pool_test.cc | 184 +++++++++--- src/ray/raylet_client/raylet_client.cc | 10 +- src/ray/raylet_client/raylet_client.h | 5 +- 35 files changed, 969 insertions(+), 184 deletions(-) create mode 100644 java/test/src/main/java/io/ray/test/JobConfigTest.java create mode 100644 java/test/src/main/java/io/ray/test/MultiDriverTest.java create mode 100644 python/ray/job_config.py create mode 100644 python/ray/tests/test_multi_tenancy.py diff --git a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java index 47eda035c..6d326bc35 100644 --- a/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/RayNativeRuntime.java @@ -11,6 +11,7 @@ import io.ray.runtime.gcs.GcsClient; import io.ray.runtime.gcs.GcsClientOptions; import io.ray.runtime.gcs.RedisClient; import io.ray.runtime.generated.Common.WorkerType; +import io.ray.runtime.generated.Gcs.JobConfig; import io.ray.runtime.object.NativeObjectStore; import io.ray.runtime.runner.RunManager; import io.ray.runtime.task.NativeTaskExecutor; @@ -106,6 +107,17 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } int numWorkersPerProcess = rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess; + + byte[] serializedJobConfig = null; + if (rayConfig.workerMode == WorkerType.DRIVER) { + JobConfig.Builder jobConfigBuilder = + JobConfig.newBuilder() + .setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess) + .addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker) + .putAllWorkerEnv(rayConfig.workerEnv); + serializedJobConfig = jobConfigBuilder.build().toByteArray(); + } + // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. nativeInitialize(rayConfig.workerMode.getNumber(), rayConfig.nodeIp, rayConfig.getNodeManagerPort(), @@ -113,7 +125,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.objectStoreSocketName, rayConfig.rayletSocketName, (rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(), new GcsClientOptions(rayConfig), numWorkersPerProcess, - rayConfig.logDir, rayConfig.rayletConfigParameters); + rayConfig.logDir, rayConfig.rayletConfigParameters, serializedJobConfig); taskExecutor = new NativeTaskExecutor(this); workerContext = new NativeWorkerContext(); @@ -201,7 +213,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { int workerMode, String ndoeIpAddress, int nodeManagerPort, String driverName, String storeSocket, String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess, - String logDir, Map rayletConfigParameters); + String logDir, Map rayletConfigParameters, byte[] serializedJobConfig); private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor); diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java index 0f9f9f96b..7432a9e1a 100644 --- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java @@ -3,6 +3,7 @@ package io.ray.runtime.config; import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.typesafe.config.Config; import com.typesafe.config.ConfigException; import com.typesafe.config.ConfigFactory; @@ -29,6 +30,8 @@ public class RayConfig { public static final String DEFAULT_CONFIG_FILE = "ray.default.conf"; public static final String CUSTOM_CONFIG_FILE = "ray.conf"; + private static int DEFAULT_NUM_JAVA_WORKER_PER_PROCESS = 10; + private static final Random RANDOM = new Random(); private static final DateTimeFormatter DATE_TIME_FORMATTER = @@ -90,6 +93,9 @@ public class RayConfig { public final int numWorkersPerProcess; + public final List jvmOptionsForJavaWorker; + public final Map workerEnv; + private void validate() { if (workerMode == WorkerType.WORKER) { Preconditions.checkArgument(redisAddress != null, @@ -141,6 +147,17 @@ public class RayConfig { this.jobId = JobId.NIL; } + // jvm options for java workers of this job. + jvmOptionsForJavaWorker = config.getStringList("ray.job.jvm-options"); + + ImmutableMap.Builder workerEnvBuilder = ImmutableMap.builder(); + Config workerEnvConfig = config.getConfig("ray.job.worker-env"); + if (workerEnvConfig != null) { + for (Map.Entry entry : workerEnvConfig.entrySet()) { + workerEnvBuilder.put(entry.getKey(), workerEnvConfig.getString(entry.getKey())); + } + } + workerEnv = workerEnvBuilder.build(); updateSessionDir(); // Object store configurations. objectStoreSize = config.getBytes("ray.object-store.size"); @@ -206,7 +223,22 @@ public class RayConfig { jobResourcePath = null; } - numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java"); + boolean enableMultiTenancy = false; + if (config.hasPath("ray.raylet.config.enable_multi_tenancy")) { + enableMultiTenancy = + Boolean.valueOf(config.getString("ray.raylet.config.enable_multi_tenancy")); + } + + if (!enableMultiTenancy) { + numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java"); + } else { + final int localNumWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process"); + if (localNumWorkersPerProcess <= 0) { + numWorkersPerProcess = DEFAULT_NUM_JAVA_WORKER_PER_PROCESS; + } else { + numWorkersPerProcess = localNumWorkersPerProcess; + } + } // Validate config. validate(); diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 79f3e3355..b95bc10d8 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -29,6 +29,15 @@ ray { // executing tasks from different jobs. E.g. if it's set to '/tm/job_resources', // the path for job 123 will be '/tmp/job_resources/123'. resource-path: "" + /// The number of java worker per worker process. + num-java-workers-per-process: 10 + /// The jvm options for java workers of the job. + jvm-options: [] + // Environment variables to be set on worker processes. + worker-env { + // key1 : "value1" + // key2 : "value2" + } } // Configurations about logging. diff --git a/java/test/src/main/java/io/ray/test/JobConfigTest.java b/java/test/src/main/java/io/ray/test/JobConfigTest.java new file mode 100644 index 000000000..ce0dfb0da --- /dev/null +++ b/java/test/src/main/java/io/ray/test/JobConfigTest.java @@ -0,0 +1,69 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +@Test(groups = {"cluster"}) +public class JobConfigTest extends BaseTest { + + @BeforeClass + public void setupJobConfig() { + System.setProperty("ray.raylet.config.enable_multi_tenancy", "true"); + System.setProperty("ray.job.num-java-workers-per-process", "3"); + System.setProperty("ray.job.jvm-options.0", "-DX=999"); + } + + @AfterClass + public void tearDownJobConfig() { + System.clearProperty("ray.raylet.config.enable_multi_tenancy"); + System.clearProperty("ray.job.num-java-workers-per-process"); + System.clearProperty("ray.job.jvm-options.0"); + } + + public static String getJvmOptions() { + return System.getProperty("X"); + } + + public static Integer getWorkersNum() { + return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess; + } + + public static class MyActor { + + public Integer getWorkersNum() { + return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess; + } + + public String getJvmOptions() { + return System.getProperty("X"); + } + } + + public void testJvmOptions() { + ObjectRef obj = Ray.task(JobConfigTest::getJvmOptions).remote(); + Assert.assertEquals("999", obj.get()); + } + + public void testNumJavaWorkerPerProcess() { + ObjectRef obj = Ray.task(JobConfigTest::getWorkersNum).remote(); + Assert.assertEquals(3, (int) obj.get()); + } + + + public void testInActor() { + ActorHandle actor = Ray.actor(MyActor::new).remote(); + + // test jvm options. + ObjectRef obj1 = actor.task(MyActor::getJvmOptions).remote(); + Assert.assertEquals("999", obj1.get()); + + // test workers number. + ObjectRef obj2 = actor.task(MyActor::getWorkersNum).remote(); + Assert.assertEquals(3, (int) obj2.get()); + } +} diff --git a/java/test/src/main/java/io/ray/test/MultiDriverTest.java b/java/test/src/main/java/io/ray/test/MultiDriverTest.java new file mode 100644 index 000000000..13009d659 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/MultiDriverTest.java @@ -0,0 +1,130 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.ObjectRef; +import io.ray.api.Ray; +import io.ray.runtime.config.RayConfig; +import io.ray.runtime.util.SystemUtil; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.ProcessBuilder.Redirect; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +@Test(groups = {"cluster"}) +public class MultiDriverTest extends BaseTest { + + private static final int DRIVER_COUNT = 10; + private static final int NORMAL_TASK_COUNT_PER_DRIVER = 100; + private static final int ACTOR_COUNT_PER_DRIVER = 10; + private static final String PID_LIST_PREFIX = "PID: "; + + @BeforeClass + public void setUpClass() { + System.setProperty("ray.raylet.config.enable_multi_tenancy", "true"); + } + + @AfterClass + public void tearDownClass() { + System.clearProperty("ray.raylet.config.enable_multi_tenancy"); + } + + static int getPid() { + return SystemUtil.pid(); + } + + public static class Actor { + + public int getPid() { + return SystemUtil.pid(); + } + } + + public static void main(String[] args) throws IOException { + Ray.init(); + + List> pidObjectList = new ArrayList<>(); + // Submit some normal tasks and get the PIDs of workers which execute the tasks. + for (int i = 0; i < NORMAL_TASK_COUNT_PER_DRIVER; ++i) { + pidObjectList.add(Ray.task(MultiDriverTest::getPid).remote()); + } + // Create some actors and get the PIDs of actors. + for (int i = 0; i < ACTOR_COUNT_PER_DRIVER; ++i) { + ActorHandle actor = Ray.actor(Actor::new).remote(); + pidObjectList.add(actor.task(Actor::getPid).remote()); + } + Set pids = new HashSet<>(); + for (ObjectRef object : pidObjectList) { + pids.add(object.get()); + } + // Write pids to stdout + System.out.println( + PID_LIST_PREFIX + pids.stream().map(String::valueOf).collect(Collectors.joining(","))); + } + + public void testMultiDrivers() throws InterruptedException, IOException { + // This test case starts some driver processes. Each driver process submits some tasks and + // collect the PIDs of the workers used by the driver. The drivers output the PID list + // which will be read by the test case itself. The test case will compare the PIDs used by + // different drivers and make sure that all the PIDs don't overlap. If overlapped, it means that + // tasks owned by different drivers were scheduled to the same worker process, that is, tasks + // of different jobs were not correctly isolated during execution. + List drivers = new ArrayList<>(); + for (int i = 0; i < DRIVER_COUNT; ++i) { + drivers.add(startDriver()); + } + + // Wait for drivers to finish. + for (Process driver : drivers) { + driver.waitFor(); + Assert.assertEquals(driver.exitValue(), 0, + "The driver exited with code " + driver.exitValue()); + } + + // Read driver outputs and check for any PID overlap. + Set pids = new HashSet<>(); + for (Process driver : drivers) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(driver.getInputStream()))) { + String line; + int previousSize = pids.size(); + while ((line = reader.readLine()) != null) { + if (line.startsWith(PID_LIST_PREFIX)) { + for (String pidString : line.substring(PID_LIST_PREFIX.length()).split(",")) { + // Make sure the PIDs don't overlap. + Assert.assertTrue(pids.add(Integer.valueOf(pidString)), + "Worker process with PID " + line + " is shared by multiple drivers."); + } + break; + } + } + int nowSize = pids.size(); + Assert.assertTrue(nowSize > previousSize); + } + } + } + + private Process startDriver() throws IOException { + RayConfig rayConfig = TestUtils.getRuntime().getRayConfig(); + + ProcessBuilder builder = new ProcessBuilder( + "java", + "-cp", + System.getProperty("java.class.path"), + "-Dray.redis.address=" + rayConfig.getRedisAddress(), + "-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName, + "-Dray.raylet.socket-name=" + rayConfig.rayletSocketName, + "-Dray.raylet.node-manager-port=" + String.valueOf(rayConfig.getNodeManagerPort()), + MultiDriverTest.class.getName()); + builder.redirectError(Redirect.INHERIT); + return builder.start(); + } +} diff --git a/java/test/src/main/java/io/ray/test/RayletConfigTest.java b/java/test/src/main/java/io/ray/test/RayletConfigTest.java index 5642014e7..18d87d00f 100644 --- a/java/test/src/main/java/io/ray/test/RayletConfigTest.java +++ b/java/test/src/main/java/io/ray/test/RayletConfigTest.java @@ -9,8 +9,8 @@ import org.testng.annotations.Test; public class RayletConfigTest extends BaseTest { - private static final String RAY_CONFIG_KEY = "num_workers_per_process_java"; - private static final String RAY_CONFIG_VALUE = "2"; + private static final String RAY_CONFIG_KEY = "get_timeout_milliseconds"; + private static final String RAY_CONFIG_VALUE = "1234"; @BeforeClass public void beforeClass() { diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 501259373..b4cf57eb0 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -658,7 +658,8 @@ cdef class CoreWorker: def __cinit__(self, is_driver, store_socket, raylet_socket, JobID job_id, GcsClientOptions gcs_options, log_dir, node_ip_address, node_manager_port, raylet_ip_address, - local_mode, driver_name, stdout_file, stderr_file): + local_mode, driver_name, stdout_file, stderr_file, + serialized_job_config): self.is_driver = is_driver self.is_local_mode = local_mode @@ -688,6 +689,7 @@ cdef class CoreWorker: options.num_workers = 1 options.kill_main = kill_main_task options.terminate_asyncio_thread = terminate_asyncio_thread + options.serialized_job_config = serialized_job_config CCoreWorkerProcess.Initialize(options) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index 8e9fa979d..7147ca97e 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -3,6 +3,7 @@ from ray.core.generated.gcs_pb2 import ( ActorTableData, GcsNodeInfo, JobTableData, + JobConfig, ErrorTableData, ErrorType, GcsEntry, @@ -25,6 +26,7 @@ __all__ = [ "ActorTableData", "GcsNodeInfo", "JobTableData", + "JobConfig", "ErrorTableData", "ErrorType", "GcsEntry", diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 03e918f25..c495bd0e7 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -226,6 +226,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: (c_bool() nogil) kill_main CCoreWorkerOptions() (void() nogil) terminate_asyncio_thread + c_string serialized_job_config cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess": @staticmethod diff --git a/python/ray/job_config.py b/python/ray/job_config.py new file mode 100644 index 000000000..b82160d99 --- /dev/null +++ b/python/ray/job_config.py @@ -0,0 +1,38 @@ +import ray + + +class JobConfig: + """A class used to store the configurations of a job. + + Attributes: + worker_env (dict): Environment variables to be set on worker + processes. + num_java_workers_per_process (int): The number of java workers per + worker process. + jvm_options (str[]): The jvm options for java workers of the job. + """ + + def __init__( + self, + worker_env=None, + num_java_workers_per_process=10, + jvm_options=None, + ): + if worker_env is None: + self.worker_env = dict() + else: + self.worker_env = worker_env + self.num_java_workers_per_process = num_java_workers_per_process + if jvm_options is None: + self.jvm_options = [] + else: + self.jvm_options = jvm_options + + def serialize(self): + job_config = ray.gcs_utils.JobConfig() + for key in self.worker_env: + job_config.worker_env[key] = self.worker_env[key] + job_config.num_java_workers_per_process = ( + self.num_java_workers_per_process) + job_config.jvm_options.extend(self.jvm_options) + return job_config.SerializeToString() diff --git a/python/ray/node.py b/python/ray/node.py index fe1e505e5..f35e2e243 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -679,7 +679,9 @@ class Node: huge_pages=self._ray_params.huge_pages, fate_share=self.kernel_fate_share, socket_to_use=self.socket, - head_node=self.head) + head_node=self.head, + start_initial_python_workers_for_first_job=self._ray_params. + start_initial_python_workers_for_first_job) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/parameter.py b/python/ray/parameter.py index 7baa8562b..7cc55f56b 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -94,6 +94,8 @@ class RayParams: lru_evict (bool): Enable LRU eviction if space is needed. enable_object_reconstruction (bool): Enable plasma reconstruction on failure. + start_initial_python_workers_for_first_job (bool): If true, start + initial Python workers for the first job on the node. """ def __init__(self, @@ -136,6 +138,7 @@ class RayParams: include_java=False, java_worker_options=None, load_code_from_local=False, + start_initial_python_workers_for_first_job=False, _internal_config=None, enable_object_reconstruction=False, metrics_agent_port=None, @@ -178,6 +181,8 @@ class RayParams: self.java_worker_options = java_worker_options self.load_code_from_local = load_code_from_local self.metrics_agent_port = metrics_agent_port + self.start_initial_python_workers_for_first_job = ( + start_initial_python_workers_for_first_job) self._internal_config = _internal_config self._lru_evict = lru_evict self._enable_object_reconstruction = enable_object_reconstruction diff --git a/python/ray/services.py b/python/ray/services.py index 0c90fe834..4b40bd282 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1271,7 +1271,8 @@ def start_raylet(redis_address, huge_pages=False, fate_share=None, socket_to_use=None, - head_node=False): + head_node=False, + start_initial_python_workers_for_first_job=False): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -1403,6 +1404,9 @@ def start_raylet(redis_address, "--session_dir={}".format(session_dir), "--metrics-agent-port={}".format(metrics_agent_port), ] + if start_initial_python_workers_for_first_job: + command.append("--num_initial_python_workers_for_first_job={}".format( + resource_spec.num_cpus)) if config.get("plasma_store_as_thread"): # command related to the plasma store plasma_directory, object_store_memory = determine_plasma_store_config( diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 4de4f950e..81c989ff5 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -111,6 +111,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_multi_tenancy", + size = "small", + srcs = SRCS + ["test_multi_tenancy.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) + py_test( name = "test_component_failures", size = "small", diff --git a/python/ray/tests/test_multi_tenancy.py b/python/ray/tests/test_multi_tenancy.py new file mode 100644 index 000000000..ade11cc8b --- /dev/null +++ b/python/ray/tests/test_multi_tenancy.py @@ -0,0 +1,116 @@ +# coding: utf-8 +import json +import sys + +import grpc +import pytest + +import ray +import ray.test_utils +from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc +from ray.test_utils import wait_for_condition, run_string_as_driver_nonblocking + + +# Test that when `redis_address` and `job_config` is not set in +# `ray.init(...)`, Raylet will start `num_cpus` Python workers for the driver. +def test_initial_workers(shutdown_only): + # `num_cpus` should be <=2 because a Travis CI machine only has 2 CPU cores + ray.init( + num_cpus=1, + include_dashboard=True, + _internal_config=json.dumps({ + "enable_multi_tenancy": True + })) + raylet = ray.nodes()[0] + raylet_address = "{}:{}".format(raylet["NodeManagerAddress"], + raylet["NodeManagerPort"]) + channel = grpc.insecure_channel(raylet_address) + stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel) + wait_for_condition(lambda: len([ + worker for worker in stub.GetNodeStats( + node_manager_pb2.GetNodeStatsRequest()).workers_stats + if not worker.is_driver + ]) == 1, + timeout=10) + + +# This test case starts some driver processes. Each driver process submits +# some tasks and collect the PIDs of the workers used by the driver. The +# drivers output the PID list which will be read by the test case itself. The +# test case will compare the PIDs used by different drivers and make sure that +# all the PIDs don't overlap. If overlapped, it means that tasks owned by +# different drivers were scheduled to the same worker process, that is, tasks +# of different jobs were not correctly isolated during execution. +def test_multi_drivers(shutdown_only): + info = ray.init( + _internal_config=json.dumps({ + "enable_multi_tenancy": True + })) + + driver_code = """ +import os +import sys +import ray + + +ray.init(address="{}") + +@ray.remote +class Actor: + def get_pid(self): + return os.getpid() + +@ray.remote +def get_pid(): + return os.getpid() + +pid_objs = [] +# Submit some normal tasks and get the PIDs of workers which execute the tasks. +pid_objs = pid_objs + [get_pid.remote() for _ in range(5)] +# Create some actors and get the PIDs of actors. +actors = [Actor.remote() for _ in range(5)] +pid_objs = pid_objs + [actor.get_pid.remote() for actor in actors] + +pids = set([ray.get(obj) for obj in pid_objs]) +# Write pids to stdout +print("PID:" + str.join(",", [str(_) for _ in pids])) + +ray.shutdown() + """.format(info["redis_address"]) + + driver_count = 10 + processes = [ + run_string_as_driver_nonblocking(driver_code) + for _ in range(driver_count) + ] + outputs = [] + for p in processes: + out = p.stdout.read().decode("ascii") + err = p.stderr.read().decode("ascii") + p.wait() + # out, err = p.communicate() + # out = ray.utils.decode(out) + # err = ray.utils.decode(err) + if p.returncode != 0: + print("Driver with PID {} returned error code {}".format( + p.pid, p.returncode)) + print("STDOUT:\n{}".format(out)) + print("STDERR:\n{}".format(err)) + outputs.append((p, out)) + + all_worker_pids = set() + for p, out in outputs: + assert p.returncode == 0 + for line in out.splitlines(): + if line.startswith("PID:"): + worker_pids = [int(_) for _ in line.split(":")[1].split(",")] + assert len(worker_pids) > 0 + for worker_pid in worker_pids: + assert worker_pid not in all_worker_pids, ( + ("Worker process with PID {} is shared" + + " by multiple drivers.").format(worker_pid)) + all_worker_pids.add(worker_pid) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 9cf7df3ec..05bc4e22d 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -20,6 +20,7 @@ import ray.cloudpickle as pickle import ray.gcs_utils import ray.memory_monitor as memory_monitor import ray.node +import ray.job_config import ray.parameter import ray.ray_constants as ray_constants import ray.remote_function @@ -486,6 +487,7 @@ def init(address=None, dashboard_host="localhost", dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT, job_id=None, + job_config=None, configure_logging=True, logging_level=logging.INFO, logging_format=ray_constants.LOGGER_FORMAT, @@ -588,6 +590,7 @@ def init(address=None, dashboard_port: The port to bind the dashboard server to. Defaults to 8265. job_id: The ID of this job. + job_config (ray.job_config.JobConfig): The job configuration. configure_logging: True (default) if configuration of logging is allowed here. Otherwise, the user may want to configure it separately. @@ -713,6 +716,7 @@ def init(address=None, temp_dir=temp_dir, load_code_from_local=load_code_from_local, java_worker_options=java_worker_options, + start_initial_python_workers_for_first_job=True, _internal_config=_internal_config, lru_evict=lru_evict, enable_object_reconstruction=enable_object_reconstruction) @@ -803,7 +807,8 @@ def init(address=None, log_to_driver=log_to_driver, worker=global_worker, driver_object_store_memory=driver_object_store_memory, - job_id=job_id) + job_id=job_id, + job_config=job_config) for hook in _post_init_hooks: hook() @@ -1147,7 +1152,8 @@ def connect(node, log_to_driver=False, worker=global_worker, driver_object_store_memory=None, - job_id=None): + job_id=None, + job_config=None): """Connect this worker to the raylet, to Plasma, and to Redis. Args: @@ -1160,6 +1166,7 @@ def connect(node, driver_object_store_memory: Limit the amount of memory the driver can use in the object store when creating objects. job_id: The ID of job. If it's None, then we will generate one. + job_config (ray.job_config.JobConfig): The job configuration. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" @@ -1265,7 +1272,9 @@ def connect(node, int(redis_port), node.redis_password, ) - + if job_config is None: + job_config = ray.job_config.JobConfig() + serialized_job_config = job_config.serialize() worker.core_worker = ray._raylet.CoreWorker( (mode == SCRIPT_MODE or mode == LOCAL_MODE), node.plasma_store_socket_name, @@ -1280,6 +1289,7 @@ def connect(node, driver_name, log_stdout_file_path, log_stderr_file_path, + serialized_job_config, ) # Create an object for interfacing with the global state. diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index dca49754e..ad1667fbf 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -337,3 +337,6 @@ RAY_CONFIG(uint32_t, max_tasks_in_flight_per_worker, 1) /// The maximum number of resource shapes included in the resource /// load reported by each raylet. RAY_CONFIG(int64_t, max_resource_shapes_per_load_report, 100) + +/// Whether to enable multi tenancy features. +RAY_CONFIG(bool, enable_multi_tenancy, false) diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7b89836d7..57a53bf1b 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -311,7 +311,8 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(), (options_.worker_type == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), options_.language, options_.node_ip_address, - &local_raylet_id, &assigned_port, &internal_config)); + &local_raylet_id, &assigned_port, &internal_config, + options_.serialized_job_config)); connected_ = true; RAY_CHECK(assigned_port != -1) diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index ad9d5fed0..b4b7151cf 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -119,6 +119,8 @@ struct CoreWorkerOptions { int num_workers; /// The function to destroy asyncio event and loops. std::function terminate_asyncio_thread; + /// Serialized representation of JobConfig. + std::string serialized_job_config; }; /// Lifecycle management of one or more `CoreWorker` instances in a process. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index f67e8e869..602913fbd 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -93,7 +93,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort, jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId, jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir, - jobject rayletConfigParameters) { + jobject rayletConfigParameters, jbyteArray jobConfig) { auto raylet_config = JavaMapToNativeMap( env, rayletConfigParameters, [](JNIEnv *env, jobject java_key) { @@ -203,6 +203,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( } }; + std::string serialized_job_config = + (jobConfig == nullptr ? "" : JavaByteArrayToNativeString(env, jobConfig)); ray::CoreWorkerOptions options = { static_cast(workerMode), // worker_type ray::Language::JAVA, // langauge @@ -228,6 +230,8 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( true, // ref_counting_enabled false, // is_local_mode static_cast(numWorkersPerProcess), // num_workers + nullptr, // terminate_asyncio_thread + serialized_job_config, // serialized_job_config }; ray::CoreWorkerProcess::Initialize(options); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index 5f9e101a1..69c05cf93 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -29,7 +29,7 @@ extern "C" { */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, - jint, jstring, jobject); + jint, jstring, jobject, jbyteArray); /* * Class: io_ray_runtime_RayNativeRuntime diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 67207dc60..738ed3cb4 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -241,6 +241,14 @@ class JavaByteArrayBuffer : public ray::Buffer { jbyte *native_bytes_; }; +/// Convert a Java byte array to a C++ string. +inline std::string JavaByteArrayToNativeString(JNIEnv *env, const jbyteArray &bytes) { + const auto size = env->GetArrayLength(bytes); + std::string str(size, 0); + env->GetByteArrayRegion(bytes, 0, size, reinterpret_cast(&str.front())); + return str; +} + /// Convert a Java byte array to a C++ UniqueID. template inline ID JavaByteArrayToId(JNIEnv *env, const jbyteArray &bytes) { @@ -516,4 +524,4 @@ inline std::string GetActorFullName(bool global, std::string name) { return global ? name : ::ray::CoreWorkerProcess::GetCoreWorker().GetCurrentJobId().Hex() + "-" + name; -} +} \ No newline at end of file diff --git a/src/ray/gcs/pb_util.h b/src/ray/gcs/pb_util.h index f9adaa3e3..2cb0ac535 100644 --- a/src/ray/gcs/pb_util.h +++ b/src/ray/gcs/pb_util.h @@ -35,14 +35,14 @@ namespace gcs { inline std::shared_ptr CreateJobTableData( const ray::JobID &job_id, bool is_dead, int64_t timestamp, const std::string &driver_ip_address, int64_t driver_pid, - const ray::rpc::JobConfigs &job_configs = {}) { + const ray::rpc::JobConfig &job_config = {}) { auto job_info_ptr = std::make_shared(); job_info_ptr->set_job_id(job_id.Binary()); job_info_ptr->set_is_dead(is_dead); job_info_ptr->set_timestamp(timestamp); job_info_ptr->set_driver_ip_address(driver_ip_address); job_info_ptr->set_driver_pid(driver_pid); - *job_info_ptr->mutable_configs() = job_configs; + *job_info_ptr->mutable_config() = job_config; return job_info_ptr; } diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index 2f668f32b..35d1fb37a 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -309,19 +309,13 @@ message TaskLeaseData { uint64 timeout = 4; } -message JobConfigs { - // The initial Python workers to start per node. If a negative value is specified, it - // fallbacks to `num_cpus`. - int32 num_initial_python_workers = 1; - // The initial Java workers to start per node. If a negative value is specified, it - // fallbacks to `num_cpus`. - int32 num_initial_java_workers = 2; +message JobConfig { // Environment variables to be set on worker processes. - map worker_env = 3; + map worker_env = 1; // The number of java workers per worker process. - uint32 num_java_workers_per_process = 4; + uint32 num_java_workers_per_process = 2; // The jvm options for java workers of the job. - repeated string jvm_options = 5; + repeated string jvm_options = 3; } message JobTableData { @@ -335,8 +329,8 @@ message JobTableData { string driver_ip_address = 4; // Process ID of the driver running this job. int64 driver_pid = 5; - // The configs of this job. - JobConfigs configs = 6; + // The config of this job. + JobConfig config = 6; } // This table stores the actor checkpoint data. An actor checkpoint diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index 7c7df6d56..9cbdf62e7 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -157,6 +157,8 @@ table RegisterClientRequest { ip_address: string; // Port that this worker is listening on. port: int; + // The config bytes of this job serialized with protobuf. + serialized_job_config: string; } table RegisterClientReply { @@ -311,6 +313,6 @@ table SetResourceRequest { } table SubscribePlasmaReady { - // ObjectID to wait for + // ObjectID to wait for object_id: string; } diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 74e9a3295..0054b2a86 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -36,6 +36,8 @@ DEFINE_int32(min_worker_port, 0, DEFINE_int32(max_worker_port, 0, "The highest port that workers' gRPC servers will bind on."); DEFINE_int32(num_initial_workers, 0, "Number of initial workers."); +DEFINE_int32(num_initial_python_workers_for_first_job, 0, + "Number of initial Python workers for the first job."); DEFINE_int32(maximum_startup_concurrency, 1, "Maximum startup concurrency"); DEFINE_string(static_resource_list, "", "The static resource list of this node."); DEFINE_string(config_list, "", "The raylet config list of this node."); @@ -71,6 +73,8 @@ int main(int argc, char *argv[]) { const int min_worker_port = static_cast(FLAGS_min_worker_port); const int max_worker_port = static_cast(FLAGS_max_worker_port); const int num_initial_workers = static_cast(FLAGS_num_initial_workers); + const int num_initial_python_workers_for_first_job = + static_cast(FLAGS_num_initial_python_workers_for_first_job); const int maximum_startup_concurrency = static_cast(FLAGS_maximum_startup_concurrency); const std::string static_resource_list = FLAGS_static_resource_list; @@ -160,6 +164,8 @@ int main(int argc, char *argv[]) { node_manager_config.node_manager_address = node_ip_address; node_manager_config.node_manager_port = node_manager_port; node_manager_config.num_initial_workers = num_initial_workers; + node_manager_config.num_initial_python_workers_for_first_job = + num_initial_python_workers_for_first_job; node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency; node_manager_config.min_worker_port = min_worker_port; node_manager_config.max_worker_port = max_worker_port; diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 0faee2122..61b96b9ec 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -162,9 +162,11 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, initial_config_(config), local_available_resources_(config.resource_config), worker_pool_( - io_service, config.num_initial_workers, config.maximum_startup_concurrency, - config.min_worker_port, config.max_worker_port, gcs_client_, - config.worker_commands, config.raylet_config, + io_service, config.num_initial_workers, + config.num_initial_python_workers_for_first_job, + config.maximum_startup_concurrency, config.min_worker_port, + config.max_worker_port, gcs_client_, config.worker_commands, + config.raylet_config, /*starting_worker_timeout_callback=*/ [this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }), scheduling_policy_(local_queues_), @@ -347,12 +349,20 @@ void NodeManager::HandleJobStarted(const JobID &job_id, const JobTableData &job_ RAY_LOG(DEBUG) << "HandleJobStarted " << job_id; RAY_CHECK(!job_data.is_dead()); - // TODO(kfstorm): Spawn job initial workers in a later PR. + worker_pool_.HandleJobStarted(job_id, job_data.config()); + if (RayConfig::instance().enable_multi_tenancy()) { + // Tasks of this job may already arrived but failed to pop a worker because the job + // config is not local yet. So we trigger dispatching again here to try to + // reschedule these tasks. + DispatchTasks(local_queues_.GetReadyTasksByClass()); + } } void NodeManager::HandleJobFinished(const JobID &job_id, const JobTableData &job_data) { RAY_LOG(DEBUG) << "HandleJobFinished " << job_id; RAY_CHECK(job_data.is_dead()); + worker_pool_.HandleJobFinished(job_id); + auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id); // Kill all the workers. The actual cleanup for these workers is done // later when we receive the DisconnectClient message from them. @@ -1218,13 +1228,31 @@ void NodeManager::ProcessRegisterClientRequestMessage( auto worker = std::dynamic_pointer_cast(std::make_shared( worker_id, language, worker_ip_address, client, client_call_manager_)); - int assigned_port; + auto send_reply_callback = [this, client](int assigned_port) { + flatbuffers::FlatBufferBuilder fbb; + std::vector internal_config_keys; + std::vector internal_config_values; + for (auto kv : initial_config_.raylet_config) { + internal_config_keys.push_back(kv.first); + internal_config_values.push_back(kv.second); + } + auto reply = ray::protocol::CreateRegisterClientReply( + fbb, to_flatbuf(fbb, self_node_id_), assigned_port, + string_vec_to_flatbuf(fbb, internal_config_keys), + string_vec_to_flatbuf(fbb, internal_config_values)); + fbb.Finish(reply); + client->WriteMessageAsync( + static_cast(protocol::MessageType::RegisterClientReply), fbb.GetSize(), + fbb.GetBufferPointer(), [this, client](const ray::Status &status) { + if (!status.ok()) { + ProcessDisconnectClientMessage(client); + } + }); + }; + if (message->is_worker()) { // Register the new worker. - if (!worker_pool_.RegisterWorker(worker, pid, &assigned_port).ok()) { - // Return -1 to signal to the worker that registration failed. - assigned_port = -1; - } + RAY_UNUSED(worker_pool_.RegisterWorker(worker, pid, send_reply_callback)); } else { // Register the new driver. RAY_CHECK(pid >= 0); @@ -1233,38 +1261,18 @@ void NodeManager::ProcessRegisterClientRequestMessage( // Compute a dummy driver task id from a given driver. const TaskID driver_task_id = TaskID::ComputeDriverTaskId(worker_id); worker->AssignTaskId(driver_task_id); - worker->AssignJobId(job_id); - Status status = worker_pool_.RegisterDriver(worker, &assigned_port); + rpc::JobConfig job_config; + job_config.ParseFromString(message->serialized_job_config()->str()); + Status status = + worker_pool_.RegisterDriver(worker, job_id, job_config, send_reply_callback); if (status.ok()) { local_queues_.AddDriverTaskId(driver_task_id); - auto job_data_ptr = gcs::CreateJobTableData( - job_id, /*is_dead*/ false, std::time(nullptr), worker_ip_address, pid); + auto job_data_ptr = + gcs::CreateJobTableData(job_id, /*is_dead*/ false, std::time(nullptr), + worker_ip_address, pid, job_config); RAY_CHECK_OK(gcs_client_->Jobs().AsyncAdd(job_data_ptr, nullptr)); - } else { - // Return -1 to signal to the worker that registration failed. - assigned_port = -1; } } - - flatbuffers::FlatBufferBuilder fbb; - std::vector internal_config_keys; - std::vector internal_config_values; - for (auto kv : initial_config_.raylet_config) { - internal_config_keys.push_back(kv.first); - internal_config_values.push_back(kv.second); - } - auto reply = ray::protocol::CreateRegisterClientReply( - fbb, to_flatbuf(fbb, self_node_id_), assigned_port, - string_vec_to_flatbuf(fbb, internal_config_keys), - string_vec_to_flatbuf(fbb, internal_config_values)); - fbb.Finish(reply); - client->WriteMessageAsync( - static_cast(protocol::MessageType::RegisterClientReply), fbb.GetSize(), - fbb.GetBufferPointer(), [this, client](const ray::Status &status) { - if (!status.ok()) { - ProcessDisconnectClientMessage(client); - } - }); } void NodeManager::ProcessAnnounceWorkerPortMessage( @@ -2690,9 +2698,11 @@ bool NodeManager::FinishAssignedTask(WorkerInterface &worker) { task_dependency_manager_.UnsubscribeGetDependencies(spec.TaskId()); task_dependency_manager_.TaskCanceled(task_id); - // Unset the worker's assigned job Id if this is not an actor. - if (!spec.IsActorCreationTask() && !spec.IsActorTask()) { - worker.AssignJobId(JobID::Nil()); + if (!RayConfig::instance().enable_multi_tenancy()) { + // Unset the worker's assigned job Id if this is not an actor. + if (!spec.IsActorCreationTask() && !spec.IsActorTask()) { + worker.AssignJobId(JobID::Nil()); + } } if (!spec.IsActorCreationTask()) { // Unset the worker's assigned task. We keep the assigned task ID for diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 9cd556813..7f4db082d 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -67,6 +67,8 @@ struct NodeManagerConfig { int max_worker_port; /// The initial number of workers to create. int num_initial_workers; + /// Number of initial Python workers for the first job. + int num_initial_python_workers_for_first_job; /// The maximum number of workers that can be started concurrently by a /// worker pool. int maximum_startup_concurrency; @@ -731,6 +733,7 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// The resources (and specific resource IDs) that are currently available. ResourceIdSet local_available_resources_; std::unordered_map cluster_resource_map_; + /// A pool of workers. WorkerPool worker_pool_; /// A set of queues to maintain tasks. diff --git a/src/ray/raylet/scheduling/cluster_task_manager.cc b/src/ray/raylet/scheduling/cluster_task_manager.cc index 47af947d3..af34d1dd5 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager.cc @@ -125,7 +125,9 @@ void ClusterTaskManager::DispatchScheduledTasksToWorkers( worker->SetAllocatedInstances(allocated_instances); } worker->AssignTaskId(spec.TaskId()); - worker->AssignJobId(spec.JobId()); + if (!RayConfig::instance().enable_multi_tenancy()) { + worker->AssignJobId(spec.JobId()); + } worker->SetAssignedTask(task); Dispatch(worker, leased_workers, spec, reply, callback); } diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index aea922f2c..ccd74f036 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -106,7 +106,20 @@ const std::unordered_set &Worker::GetBlockedTaskIds() const { return blocked_task_ids_; } -void Worker::AssignJobId(const JobID &job_id) { assigned_job_id_ = job_id; } +void Worker::AssignJobId(const JobID &job_id) { + if (!RayConfig::instance().enable_multi_tenancy()) { + assigned_job_id_ = job_id; + } else { + if (!assigned_job_id_.IsNil()) { + RAY_CHECK(assigned_job_id_ == job_id) + << "The worker " << worker_id_ << " is already assigned to job " + << assigned_job_id_ << ". It cannot be reassigned to job " << job_id; + } else { + assigned_job_id_ = job_id; + RAY_LOG(INFO) << "Assigned worker " << worker_id_ << " to job " << job_id; + } + } +} const JobID &Worker::GetAssignedJobId() const { return assigned_job_id_; } diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index ed08bc336..167fbdcfe 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -54,9 +54,8 @@ namespace ray { namespace raylet { -/// A constructor that initializes a worker pool with num_workers workers for -/// each language. WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, + int num_initial_python_workers_for_first_job, int maximum_startup_concurrency, int min_worker_port, int max_worker_port, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, @@ -66,7 +65,12 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, maximum_startup_concurrency_(maximum_startup_concurrency), gcs_client_(std::move(gcs_client)), raylet_config_(raylet_config), - starting_worker_timeout_callback_(starting_worker_timeout_callback) { + starting_worker_timeout_callback_(starting_worker_timeout_callback), + first_job_registered_python_worker_count_(0), + first_job_driver_wait_num_python_workers_(std::min( + num_initial_python_workers_for_first_job, maximum_startup_concurrency)), + num_initial_python_workers_for_first_job_( + num_initial_python_workers_for_first_job) { RAY_CHECK(maximum_startup_concurrency > 0); #ifndef _WIN32 // Ignore SIGCHLD signals. If we don't do this, then worker processes will @@ -76,25 +80,29 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, for (const auto &entry : worker_commands) { // Initialize the pool state for this language. auto &state = states_by_lang_[entry.first]; - switch (entry.first) { - case Language::PYTHON: - state.num_workers_per_process = - RayConfig::instance().num_workers_per_process_python(); - break; - case Language::JAVA: - state.num_workers_per_process = - RayConfig::instance().num_workers_per_process_java(); - break; - default: - RAY_LOG(FATAL) << "The number of workers per process for " - << Language_Name(entry.first) << " worker is not set."; + if (!RayConfig::instance().enable_multi_tenancy()) { + switch (entry.first) { + case Language::PYTHON: + state.num_workers_per_process = + RayConfig::instance().num_workers_per_process_python(); + break; + case Language::JAVA: + state.num_workers_per_process = + RayConfig::instance().num_workers_per_process_java(); + break; + default: + RAY_LOG(FATAL) << "The number of workers per process for " + << Language_Name(entry.first) << " worker is not set."; + } + RAY_CHECK(state.num_workers_per_process > 0) + << "Number of workers per process of language " << Language_Name(entry.first) + << " must be positive."; + state.multiple_for_warning = + std::max(state.num_workers_per_process, + std::max(num_workers, maximum_startup_concurrency)); + } else { + state.multiple_for_warning = maximum_startup_concurrency; } - RAY_CHECK(state.num_workers_per_process > 0) - << "Number of workers per process of language " << Language_Name(entry.first) - << " must be positive."; - state.multiple_for_warning = - std::max(state.num_workers_per_process, - std::max(num_workers, maximum_startup_concurrency)); // Set worker command for this language. state.worker_command = entry.second; RAY_CHECK(!state.worker_command.empty()) << "Worker command must not be empty."; @@ -111,16 +119,19 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, free_ports_->push(port); } } - Start(num_workers); + if (!RayConfig::instance().enable_multi_tenancy()) { + Start(num_workers); + } } void WorkerPool::Start(int num_workers) { + RAY_CHECK(!RayConfig::instance().enable_multi_tenancy()); for (auto &entry : states_by_lang_) { auto &state = entry.second; int num_worker_processes = static_cast( std::ceil(static_cast(num_workers) / state.num_workers_per_process)); for (int i = 0; i < num_worker_processes; i++) { - StartWorkerProcess(entry.first); + StartWorkerProcess(entry.first, JobID::Nil()); } } } @@ -154,8 +165,20 @@ uint32_t WorkerPool::Size(const Language &language) const { } } -Process WorkerPool::StartWorkerProcess(const Language &language, - const std::vector &dynamic_options) { +Process WorkerPool::StartWorkerProcess(const Language &language, const JobID &job_id, + std::vector dynamic_options) { + rpc::JobConfig *job_config = nullptr; + if (RayConfig::instance().enable_multi_tenancy()) { + RAY_CHECK(!job_id.IsNil()); + auto it = unfinished_jobs_.find(job_id); + if (it == unfinished_jobs_.end()) { + RAY_LOG(DEBUG) << "Job config of job " << job_id << " are not local yet."; + // Will reschedule ready tasks in `NodeManager::HandleJobStarted`. + return Process(); + } + job_config = &it->second; + } + auto &state = GetStateForLanguage(language); // If we are already starting up too many workers, then return without starting // more. @@ -175,11 +198,21 @@ Process WorkerPool::StartWorkerProcess(const Language &language, << state.idle_actor.size() << " actor workers, and " << state.idle.size() << " non-actor workers"; - int workers_to_start; + int workers_to_start = 1; if (dynamic_options.empty()) { - workers_to_start = state.num_workers_per_process; - } else { - workers_to_start = 1; + if (!RayConfig::instance().enable_multi_tenancy()) { + workers_to_start = state.num_workers_per_process; + } else if (language == Language::JAVA) { + workers_to_start = job_config->num_java_workers_per_process(); + } + } + + if (RayConfig::instance().enable_multi_tenancy() && + !job_config->jvm_options().empty()) { + // Note that we push the item to the front of the vector to make + // sure this is the freshest option than others. + dynamic_options.insert(dynamic_options.begin(), job_config->jvm_options().begin(), + job_config->jvm_options().end()); } // Extract pointers from the worker command to pass into execvp. @@ -216,12 +249,17 @@ Process WorkerPool::StartWorkerProcess(const Language &language, arg.append(entry.second); worker_command_args.push_back(arg); } - // The value of `num_workers_per_process_java` may change depends on whether - // dynamic options is empty, so we can't use the value in `RayConfig`. We always - // overwrite the value here. - worker_command_args.push_back( - "-Dray.raylet.config.num_workers_per_process_java=" + - std::to_string(workers_to_start)); + if (!RayConfig::instance().enable_multi_tenancy()) { + // The value of `num_workers_per_process_java` may change depends on whether + // dynamic options is empty, so we can't use the value in `RayConfig`. We always + // overwrite the value here. + worker_command_args.push_back( + "-Dray.raylet.config.num_workers_per_process_java=" + + std::to_string(workers_to_start)); + } else { + worker_command_args.push_back("-Dray.job.num-java-workers-per-process=" + + std::to_string(workers_to_start)); + } break; default: RAY_LOG(FATAL) @@ -241,7 +279,14 @@ Process WorkerPool::StartWorkerProcess(const Language &language, << " placeholder is not found in worker command."; } + // TODO(kfstorm): Set up environment variables in a later PR. Process proc = StartProcess(worker_command_args); + if (RayConfig::instance().enable_multi_tenancy()) { + // If the pid is reused between processes, the old process must have exited. + // So it's safe to bind the pid with another job ID. + RAY_LOG(DEBUG) << "Worker process " << proc.GetId() << " is bound to job " << job_id; + state.worker_pids_to_assigned_jobs[proc.GetId()] = job_id; + } RAY_LOG(DEBUG) << "Started worker process of " << workers_to_start << " worker(s) with pid " << proc.GetId(); MonitorStartingWorkerProcess(proc, language); @@ -327,34 +372,120 @@ void WorkerPool::MarkPortAsFree(int port) { } } +void WorkerPool::HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config) { + unfinished_jobs_[job_id] = job_config; +} + +void WorkerPool::HandleJobFinished(const JobID &job_id) { + unfinished_jobs_.erase(job_id); +} + Status WorkerPool::RegisterWorker(const std::shared_ptr &worker, - pid_t pid, int *port) { + pid_t pid, + std::function send_reply_callback) { + RAY_CHECK(worker); + + // The port that this worker's gRPC server should listen on. 0 if the worker + // should bind on a random port. + int port; + Status status; + auto &state = GetStateForLanguage(worker->GetLanguage()); auto it = state.starting_worker_processes.find(Process::FromPid(pid)); if (it == state.starting_worker_processes.end()) { RAY_LOG(WARNING) << "Received a register request from an unknown worker " << pid; - return Status::Invalid("Unknown worker"); - } - RAY_RETURN_NOT_OK(GetNextFreePort(port)); - RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << *port; - worker->SetAssignedPort(*port); - worker->SetProcess(it->first); - it->second--; - if (it->second == 0) { - state.starting_worker_processes.erase(it); + // Return -1 to signal to the worker that registration failed. + port = -1; + status = Status::Invalid("Unknown worker"); + } else { + RAY_RETURN_NOT_OK(GetNextFreePort(&port)); + RAY_LOG(DEBUG) << "Registering worker with pid " << pid << ", port: " << port; + worker->SetAssignedPort(port); + worker->SetProcess(it->first); + it->second--; + if (it->second == 0) { + state.starting_worker_processes.erase(it); + } + + RAY_CHECK(worker->GetProcess().GetId() == pid); + state.registered_workers.insert(worker); + + if (RayConfig::instance().enable_multi_tenancy()) { + auto dedicated_workers_it = state.worker_pids_to_assigned_jobs.find(pid); + RAY_CHECK(dedicated_workers_it != state.worker_pids_to_assigned_jobs.end()); + auto job_id = dedicated_workers_it->second; + worker->AssignJobId(job_id); + // We don't call state.worker_pids_to_assigned_jobs.erase(job_id) here + // because we allow multi-workers per worker process. + + // This is a workaround to finish driver registration after all initial workers are + // registered to Raylet if and only if Raylet is started by a Python driver and the + // job config is not set in `ray.init(...)`. + if (first_job_ == job_id && worker->GetLanguage() == Language::PYTHON) { + if (++first_job_registered_python_worker_count_ == + first_job_driver_wait_num_python_workers_) { + if (first_job_send_register_client_reply_to_driver_) { + first_job_send_register_client_reply_to_driver_(); + first_job_send_register_client_reply_to_driver_ = nullptr; + } + } + } + } + + status = Status::OK(); } - state.registered_workers.emplace(std::move(worker)); - return Status::OK(); + // Send the reply immediately for worker registrations. + if (send_reply_callback) { + send_reply_callback(port); + } + return status; } Status WorkerPool::RegisterDriver(const std::shared_ptr &driver, - int *port) { + const JobID &job_id, const rpc::JobConfig &job_config, + std::function send_reply_callback) { + int port; RAY_CHECK(!driver->GetAssignedTaskId().IsNil()); - RAY_RETURN_NOT_OK(GetNextFreePort(port)); - driver->SetAssignedPort(*port); + RAY_RETURN_NOT_OK(GetNextFreePort(&port)); + driver->SetAssignedPort(port); auto &state = GetStateForLanguage(driver->GetLanguage()); state.registered_drivers.insert(std::move(driver)); + driver->AssignJobId(job_id); + unfinished_jobs_[job_id] = job_config; + + if (send_reply_callback) { + // This is a workaround to start initial workers on this node if and only if Raylet is + // started by a Python driver and the job config is not set in `ray.init(...)`. + // Invoke the `send_reply_callback` later to only finish driver + // registration after all initial workers are registered to Raylet. + bool delay_callback = false; + // Multi-tenancy is enabled. + if (RayConfig().instance().enable_multi_tenancy()) { + // If this is the first job. + if (first_job_.IsNil()) { + first_job_ = job_id; + // If the number of Python workers we need to wait is positive. + if (num_initial_python_workers_for_first_job_ > 0) { + delay_callback = true; + // Start initial Python workers for the first job. + for (int i = 0; i < num_initial_python_workers_for_first_job_; i++) { + StartWorkerProcess(Language::PYTHON, job_id); + } + } + } + } + + if (delay_callback) { + RAY_CHECK(!first_job_send_register_client_reply_to_driver_); + first_job_send_register_client_reply_to_driver_ = [send_reply_callback, port]() { + send_reply_callback(port); + }; + } else { + send_reply_callback(port); + } + } + return Status::OK(); } @@ -424,8 +555,8 @@ std::shared_ptr WorkerPool::PopWorker( } else if (!HasPendingWorkerForTask(task_spec.GetLanguage(), task_spec.TaskId())) { // We are not pending a registration from a worker for this task, // so start a new worker process for this task. - proc = - StartWorkerProcess(task_spec.GetLanguage(), task_spec.DynamicWorkerOptions()); + proc = StartWorkerProcess(task_spec.GetLanguage(), task_spec.JobId(), + task_spec.DynamicWorkerOptions()); if (proc.IsValid()) { state.dedicated_workers_to_tasks[proc] = task_spec.TaskId(); state.tasks_to_dedicated_workers[task_spec.TaskId()] = proc; @@ -433,13 +564,30 @@ std::shared_ptr WorkerPool::PopWorker( } } else if (!task_spec.IsActorTask()) { // Code path of normal task or actor creation task without dynamic worker options. - if (!state.idle.empty()) { - worker = std::move(*state.idle.begin()); - state.idle.erase(state.idle.begin()); + if (!RayConfig::instance().enable_multi_tenancy()) { + if (!state.idle.empty()) { + worker = std::move(*state.idle.begin()); + state.idle.erase(state.idle.begin()); + } else { + // There are no more non-actor workers available to execute this task. + // Start a new worker process. + proc = StartWorkerProcess(task_spec.GetLanguage(), JobID::Nil()); + } } else { - // There are no more non-actor workers available to execute this task. - // Start a new worker process. - proc = StartWorkerProcess(task_spec.GetLanguage()); + // Find an available worker which is already assigned to this job. + for (auto it = state.idle.begin(); it != state.idle.end(); it++) { + if ((*it)->GetAssignedJobId() != task_spec.JobId()) { + continue; + } + worker = std::move(*it); + state.idle.erase(it); + break; + } + if (worker == nullptr) { + // There are no more non-actor workers available to execute this task. + // Start a new worker process. + proc = StartWorkerProcess(task_spec.GetLanguage(), task_spec.JobId()); + } } } else { // Code path of actor task. @@ -455,6 +603,9 @@ std::shared_ptr WorkerPool::PopWorker( WarnAboutSize(); } + if (RayConfig::instance().enable_multi_tenancy() && worker) { + RAY_CHECK(worker->GetAssignedJobId() == task_spec.JobId()); + } return worker; } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index ee25afe70..64c497d42 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -73,6 +73,8 @@ class WorkerPool : public WorkerPoolInterface { /// the pool. /// /// \param num_workers The number of workers to start, per language. + /// \param num_initial_python_workers_for_first_job The number of initial Python + /// workers for the first job. /// \param maximum_startup_concurrency The maximum number of worker processes /// that can be started in parallel (typically this should be set to the number of CPU /// resources on the machine). @@ -86,6 +88,7 @@ class WorkerPool : public WorkerPoolInterface { /// \param starting_worker_timeout_callback The callback that will be triggered once /// it times out to start a worker. WorkerPool(boost::asio::io_service &io_service, int num_workers, + int num_initial_python_workers_for_first_job, int maximum_startup_concurrency, int min_worker_port, int max_worker_port, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, @@ -95,24 +98,42 @@ class WorkerPool : public WorkerPoolInterface { /// Destructor responsible for freeing a set of workers owned by this class. virtual ~WorkerPool(); + /// Handles the event that a job is started. + /// + /// \param job_id ID of the started job. + /// \param job_config The config of the started job. + /// \return Void + void HandleJobStarted(const JobID &job_id, const rpc::JobConfig &job_config); + + /// Handles the event that a job is finished. + /// + /// \param job_id ID of the finished job. + /// \return Void. + void HandleJobFinished(const JobID &job_id); + /// Register a new worker. The Worker should be added by the caller to the /// pool after it becomes idle (e.g., requests a work assignment). /// /// \param[in] worker The worker to be registered. /// \param[in] pid The PID of the worker. - /// \param[out] port The port that this worker's gRPC server should listen on. + /// \param[in] send_reply_callback The callback to invoke after registration is + /// finished/failed. /// Returns 0 if the worker should bind on a random port. /// \return If the registration is successful. Status RegisterWorker(const std::shared_ptr &worker, pid_t pid, - int *port); + std::function send_reply_callback); /// Register a new driver. /// /// \param[in] worker The driver to be registered. - /// \param[out] port The port that this driver's gRPC server should listen on. - /// Returns 0 if the driver should bind on a random port. + /// \param[in] job_id The job ID of the driver. + /// \param[in] job_config The config of the job. + /// \param[in] send_reply_callback The callback to invoke after registration is + /// finished/failed. /// \return If the registration is successful. - Status RegisterDriver(const std::shared_ptr &worker, int *port); + Status RegisterDriver(const std::shared_ptr &worker, + const JobID &job_id, const rpc::JobConfig &job_config, + std::function send_reply_callback); /// Get the client connection's registered worker. /// @@ -208,11 +229,12 @@ class WorkerPool : public WorkerPoolInterface { /// any workers. /// /// \param language Which language this worker process should be. + /// \param job_id The ID of the job to which the started worker process belongs. /// \param dynamic_options The dynamic options that we should add for worker command. /// \return The id of the process that we started if it's positive, /// otherwise it means we didn't start a process. - Process StartWorkerProcess(const Language &language, - const std::vector &dynamic_options = {}); + Process StartWorkerProcess(const Language &language, const JobID &job_id, + std::vector dynamic_options = {}); /// The implementation of how to start a new worker process with command arguments. /// The lifetime of the process is tied to that of the returned object, @@ -251,6 +273,8 @@ class WorkerPool : public WorkerPoolInterface { std::unordered_map dedicated_workers_to_tasks; /// A map for speeding up looking up the pending worker for the given task. std::unordered_map tasks_to_dedicated_workers; + /// A map for looking up the owner JobId by the pid of worker. + std::unordered_map worker_pids_to_assigned_jobs; /// We'll push a warning to the user every time a multiple of this many /// worker processes has been started. int multiple_for_warning; @@ -307,6 +331,25 @@ class WorkerPool : public WorkerPoolInterface { /// The callback that will be triggered once it times out to start a worker. std::function starting_worker_timeout_callback_; FRIEND_TEST(WorkerPoolTest, InitialWorkerProcessCount); + + /// The Job ID of the firstly received job. + JobID first_job_; + + /// The callback to send RegisterClientReply to the driver of the first job. + std::function first_job_send_register_client_reply_to_driver_; + + /// The number of registered workers of the first job. + int first_job_registered_python_worker_count_; + + /// The umber of initial Python workers to wait for the first job before the driver + /// receives RegisterClientReply. + int first_job_driver_wait_num_python_workers_; + + /// The number of initial Python workers for the first job. + int num_initial_python_workers_for_first_job_; + + /// This map tracks the latest infos of unfinished jobs. + absl::flat_hash_map unfinished_jobs_; }; } // namespace raylet diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 9027c6dcb..d8cbbfea3 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -26,21 +26,15 @@ namespace raylet { int NUM_WORKERS_PER_PROCESS_JAVA = 3; int MAXIMUM_STARTUP_CONCURRENCY = 5; +JobID JOB_ID = JobID::FromInt(1); std::vector LANGUAGES = {Language::PYTHON, Language::JAVA}; class WorkerPoolMock : public WorkerPool { public: - WorkerPoolMock(boost::asio::io_service &io_service) - : WorkerPoolMock( - io_service, - {{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, - {"dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"}}}) {} - explicit WorkerPoolMock(boost::asio::io_service &io_service, const WorkerCommandMap &worker_commands) - : WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr, + : WorkerPool(io_service, 0, 0, MAXIMUM_STARTUP_CONCURRENCY, 0, 0, nullptr, worker_commands, {}, []() {}), last_worker_process_() { states_by_lang_[ray::Language::JAVA].num_workers_per_process = @@ -94,14 +88,22 @@ class WorkerPoolMock : public WorkerPool { std::unordered_map> worker_commands_by_proc_; }; -class WorkerPoolTest : public ::testing::Test { +class WorkerPoolTest : public ::testing::TestWithParam { public: WorkerPoolTest() : error_message_type_(1), client_call_manager_(io_service_) { - worker_pool_ = std::unique_ptr(new WorkerPoolMock(io_service_)); + bool enable_multi_tenancy = GetParam(); + RayConfig::instance().initialize( + {{"enable_multi_tenancy", std::to_string(enable_multi_tenancy)}, + {"num_workers_per_process_java", std::to_string(NUM_WORKERS_PER_PROCESS_JAVA)}}); + SetWorkerCommands( + {{Language::PYTHON, {"dummy_py_worker_command"}}, + {Language::JAVA, + {"dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"}}}); } std::shared_ptr CreateWorker( - Process proc, const Language &language = Language::PYTHON) { + Process proc, const Language &language = Language::PYTHON, + const JobID &job_id = JOB_ID) { std::function client_handler = [this](ClientConnection &client) { HandleNewClient(client); }; std::function, int64_t, @@ -119,15 +121,28 @@ class WorkerPoolTest : public ::testing::Test { WorkerID::FromRandom(), language, "127.0.0.1", client, client_call_manager_); std::shared_ptr worker = std::dynamic_pointer_cast(worker_); + worker->AssignJobId(job_id); if (!proc.IsNull()) { worker->SetProcess(proc); } return worker; } + std::shared_ptr RegisterDriver( + const Language &language = Language::PYTHON, const JobID &job_id = JOB_ID, + const rpc::JobConfig &job_config = rpc::JobConfig()) { + auto driver = CreateWorker(Process::CreateNewDummy(), Language::PYTHON, job_id); + driver->AssignTaskId(TaskID::ForDriverTask(job_id)); + RAY_CHECK_OK(worker_pool_->RegisterDriver(driver, job_id, job_config, nullptr)); + return driver; + } + void SetWorkerCommands(const WorkerCommandMap &worker_commands) { worker_pool_ = std::unique_ptr(new WorkerPoolMock(io_service_, worker_commands)); + rpc::JobConfig job_config; + job_config.set_num_java_workers_per_process(NUM_WORKERS_PER_PROCESS_JAVA); + RegisterDriver(Language::PYTHON, JOB_ID, job_config); } void TestStartupWorkerProcessCount(Language language, int num_workers_per_process, @@ -139,7 +154,7 @@ class WorkerPoolTest : public ::testing::Test { static_cast(desired_initial_worker_process_count)); Process last_started_worker_process; for (int i = 0; i < desired_initial_worker_process_count; i++) { - worker_pool_->StartWorkerProcess(language); + worker_pool_->StartWorkerProcess(language, JOB_ID); ASSERT_TRUE(worker_pool_->NumWorkerProcessesStarting() <= expected_worker_process_count); Process prev = worker_pool_->LastStartedWorkerProcess(); @@ -172,9 +187,10 @@ class WorkerPoolTest : public ::testing::Test { static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, - const ActorID actor_creation_id = ActorID::Nil(), + const JobID &job_id = JOB_ID, const ActorID actor_creation_id = ActorID::Nil(), const std::vector &dynamic_worker_options = {}) { rpc::TaskSpec message; + message.set_job_id(job_id.Binary()); message.set_language(language); if (!actor_id.IsNil()) { message.set_type(TaskType::ACTOR_TASK); @@ -191,7 +207,17 @@ static inline TaskSpecification ExampleTaskSpec( return TaskSpecification(std::move(message)); } -TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) { +static inline std::string GetNumJavaWorkersPerProcessSystemProperty(int num) { + std::string key; + if (RayConfig::instance().enable_multi_tenancy()) { + key = "ray.job.num-java-workers-per-process"; + } else { + key = "ray.raylet.config.num_workers_per_process_java"; + } + return std::string("-D") + key + "=" + std::to_string(num); +} + +TEST_P(WorkerPoolTest, CompareWorkerProcessObjects) { typedef Process T; T a(T::CreateNewDummy()), b(T::CreateNewDummy()), empty = T(); ASSERT_TRUE(empty.IsNull()); @@ -205,8 +231,8 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) { ASSERT_TRUE(!std::equal_to()(a, empty)); } -TEST_F(WorkerPoolTest, HandleWorkerRegistration) { - Process proc = worker_pool_->StartWorkerProcess(Language::JAVA); +TEST_P(WorkerPoolTest, HandleWorkerRegistration) { + Process proc = worker_pool_->StartWorkerProcess(Language::JAVA, JOB_ID); std::vector> workers; for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { workers.push_back(CreateWorker(Process(), Language::JAVA)); @@ -217,8 +243,7 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 1); // Check that we cannot lookup the worker before it's registered. ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), nullptr); - int port; - RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), &port)); + RAY_CHECK_OK(worker_pool_->RegisterWorker(worker, proc.GetId(), nullptr)); // Check that we can lookup the worker after it's registered. ASSERT_EQ(worker_pool_->GetRegisteredWorker(worker->Connection()), worker); } @@ -231,30 +256,34 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { } } -TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) { +TEST_P(WorkerPoolTest, StartupPythonWorkerProcessCount) { TestStartupWorkerProcessCount(Language::PYTHON, 1, {"dummy_py_worker_command"}); } -TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) { +TEST_P(WorkerPoolTest, StartupJavaWorkerProcessCount) { TestStartupWorkerProcessCount( Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA, {"dummy_java_worker_command", - std::string("-Dray.raylet.config.num_workers_per_process_java=") + - std::to_string(NUM_WORKERS_PER_PROCESS_JAVA)}); + GetNumJavaWorkersPerProcessSystemProperty(NUM_WORKERS_PER_PROCESS_JAVA)}); } -TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { - worker_pool_->Start(1); - // Here we try to start only 1 worker for each worker language. But since each Java - // worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here, - // it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for - // each worker language. - ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size()); - ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA); - ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size()); +TEST_P(WorkerPoolTest, InitialWorkerProcessCount) { + if (!RayConfig::instance().enable_multi_tenancy()) { + worker_pool_->Start(1); + // Here we try to start only 1 worker for each worker language. But since each Java + // worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here, + // it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for + // each worker language. + ASSERT_NE(worker_pool_->NumWorkersStarting(), 1 * LANGUAGES.size()); + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), LANGUAGES.size()); + } else { + ASSERT_EQ(worker_pool_->NumWorkersStarting(), 0); + ASSERT_EQ(worker_pool_->NumWorkerProcessesStarting(), 0); + } } -TEST_F(WorkerPoolTest, HandleWorkerPushPop) { +TEST_P(WorkerPoolTest, HandleWorkerPushPop) { // Try to pop a worker from the empty pool and make sure we don't get one. std::shared_ptr popped_worker; const auto task_spec = ExampleTaskSpec(); @@ -281,7 +310,7 @@ TEST_F(WorkerPoolTest, HandleWorkerPushPop) { ASSERT_EQ(popped_worker, nullptr); } -TEST_F(WorkerPoolTest, PopActorWorker) { +TEST_P(WorkerPoolTest, PopActorWorker) { // Create a worker. auto worker = CreateWorker(Process::CreateNewDummy()); // Add the worker to the pool. @@ -290,8 +319,7 @@ TEST_F(WorkerPoolTest, PopActorWorker) { // Assign an actor ID to the worker. const auto task_spec = ExampleTaskSpec(); auto actor = worker_pool_->PopWorker(task_spec); - const auto job_id = JobID::FromInt(1); - auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); + auto actor_id = ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1); actor->AssignActorId(actor_id); worker_pool_->PushWorker(actor); @@ -304,7 +332,7 @@ TEST_F(WorkerPoolTest, PopActorWorker) { ASSERT_EQ(actor->GetActorId(), actor_id); } -TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { +TEST_P(WorkerPoolTest, PopWorkersOfMultipleLanguages) { // Create a Python Worker, and add it to the pool auto py_worker = CreateWorker(Process::CreateNewDummy(), Language::PYTHON); worker_pool_->PushWorker(py_worker); @@ -322,26 +350,94 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { ASSERT_NE(worker_pool_->PopWorker(java_task_spec), nullptr); } -TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { +TEST_P(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { const std::vector java_worker_command = { "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0", "dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER", "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1"}; SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); - const auto job_id = JobID::FromInt(1); TaskSpecification task_spec = ExampleTaskSpec( - ActorID::Nil(), Language::JAVA, - ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1), {"test_op_0", "test_op_1"}); - worker_pool_->StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); + ActorID::Nil(), Language::JAVA, JOB_ID, + ActorID::Of(JOB_ID, TaskID::ForDriverTask(JOB_ID), 1), {"test_op_0", "test_op_1"}); + worker_pool_->StartWorkerProcess(Language::JAVA, JOB_ID, + task_spec.DynamicWorkerOptions()); const auto real_command = worker_pool_->GetWorkerCommand(worker_pool_->LastStartedWorkerProcess()); ASSERT_EQ(real_command, - std::vector( - {"test_op_0", "dummy_java_worker_command", - "-Dray.raylet.config.num_workers_per_process_java=1", "test_op_1"})); + std::vector({"test_op_0", "dummy_java_worker_command", + GetNumJavaWorkersPerProcessSystemProperty(1), + "test_op_1"})); } +TEST_P(WorkerPoolTest, PopWorkerMultiTenancy) { + if (!RayConfig::instance().enable_multi_tenancy()) { + return; + } + + auto job_id1 = JOB_ID; + auto job_id2 = JobID::FromInt(2); + ASSERT_NE(job_id1, job_id2); + JobID job_ids[] = {job_id1, job_id2}; + + // The driver of job 1 is already registered. Here we register the driver for job 2. + RegisterDriver(Language::PYTHON, job_id2); + + // Register 2 workers for each job. + for (auto job_id : job_ids) { + for (int i = 0; i < 2; i++) { + auto worker = CreateWorker(Process::CreateNewDummy(), Language::PYTHON, job_id); + worker_pool_->PushWorker(worker); + } + } + + std::unordered_set worker_ids; + for (int round = 0; round < 2; round++) { + std::vector> workers; + + // Pop workers for actor (creation) tasks. + for (auto job_id : job_ids) { + auto actor_id = ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 1); + // For the first round, we pop for actor creation tasks. + // For the second round, we pop for actor tasks. + auto task_spec = + ExampleTaskSpec(round == 0 ? ActorID::Nil() : actor_id, Language::PYTHON, + job_id, round == 0 ? actor_id : ActorID::Nil()); + auto worker = worker_pool_->PopWorker(task_spec); + ASSERT_TRUE(worker); + ASSERT_EQ(worker->GetAssignedJobId(), job_id); + if (round == 0) { + worker->AssignActorId(actor_id); + } + workers.push_back(worker); + } + + // Pop workers for normal tasks. + for (auto job_id : job_ids) { + auto task_spec = ExampleTaskSpec(ActorID::Nil(), Language::PYTHON, job_id); + auto worker = worker_pool_->PopWorker(task_spec); + ASSERT_TRUE(worker); + ASSERT_EQ(worker->GetAssignedJobId(), job_id); + workers.push_back(worker); + } + + // Return all workers. + for (auto worker : workers) { + worker_pool_->PushWorker(worker); + if (round == 0) { + // For the first round, all workers are new. + ASSERT_TRUE(worker_ids.insert(worker->WorkerId()).second); + } else { + // For the second round, all workers are existing ones. + ASSERT_TRUE(worker_ids.count(worker->WorkerId()) > 0); + } + } + } +} + +INSTANTIATE_TEST_CASE_P(WorkerPoolMultiTenancyTest, WorkerPoolTest, + ::testing::Values(true, false)); + } // namespace raylet } // namespace ray diff --git a/src/ray/raylet_client/raylet_client.cc b/src/ray/raylet_client/raylet_client.cc index b2b2020c3..e9ff24a0c 100644 --- a/src/ray/raylet_client/raylet_client.cc +++ b/src/ray/raylet_client/raylet_client.cc @@ -83,8 +83,12 @@ raylet::RayletClient::RayletClient( const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, const std::string &ip_address, ClientID *raylet_id, int *port, - std::unordered_map *internal_config) - : grpc_client_(std::move(grpc_client)), worker_id_(worker_id), job_id_(job_id) { + std::unordered_map *internal_config, + const std::string &job_config) + : grpc_client_(std::move(grpc_client)), + worker_id_(worker_id), + job_id_(job_id), + job_config_(job_config) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr( new raylet::RayletConnection(io_service, raylet_socket, -1, -1)); @@ -92,7 +96,7 @@ raylet::RayletClient::RayletClient( flatbuffers::FlatBufferBuilder fbb; auto message = protocol::CreateRegisterClientRequest( fbb, is_worker, to_flatbuf(fbb, worker_id), getpid(), to_flatbuf(fbb, job_id), - language, fbb.CreateString(ip_address)); + language, fbb.CreateString(ip_address), /*port=*/0, fbb.CreateString(job_config_)); fbb.Finish(message); // Register the process ID with the raylet. // NOTE(swang): If raylet exits and we are registered as a worker, we will get killed. diff --git a/src/ray/raylet_client/raylet_client.h b/src/ray/raylet_client/raylet_client.h index cbb42edbe..78421f43a 100644 --- a/src/ray/raylet_client/raylet_client.h +++ b/src/ray/raylet_client/raylet_client.h @@ -177,7 +177,8 @@ class RayletClient : public PinObjectsInterface, const std::string &raylet_socket, const WorkerID &worker_id, bool is_worker, const JobID &job_id, const Language &language, const std::string &ip_address, ClientID *raylet_id, int *port, - std::unordered_map *internal_config); + std::unordered_map *internal_config, + const std::string &job_config); /// Connect to the raylet via grpc only. /// @@ -369,6 +370,8 @@ class RayletClient : public PinObjectsInterface, std::shared_ptr grpc_client_; const WorkerID worker_id_; const JobID job_id_; + const std::string job_config_; + /// A map from resource name to the resource IDs that are currently reserved /// for this worker. Each pair consists of the resource ID and the fraction /// of that resource allocated for this worker.