[Core] Multi-tenancy: Job isolation & implement per job config (except for env variables) (#9500)

This commit is contained in:
Kai Yang
2020-08-04 15:51:29 +08:00
committed by GitHub
parent 28b1f7710c
commit 27cd323ce1
35 changed files with 969 additions and 184 deletions
+3 -1
View File
@@ -658,7 +658,8 @@ cdef class CoreWorker:
def __cinit__(self, is_driver, store_socket, raylet_socket,
JobID job_id, GcsClientOptions gcs_options, log_dir,
node_ip_address, node_manager_port, raylet_ip_address,
local_mode, driver_name, stdout_file, stderr_file):
local_mode, driver_name, stdout_file, stderr_file,
serialized_job_config):
self.is_driver = is_driver
self.is_local_mode = local_mode
@@ -688,6 +689,7 @@ cdef class CoreWorker:
options.num_workers = 1
options.kill_main = kill_main_task
options.terminate_asyncio_thread = terminate_asyncio_thread
options.serialized_job_config = serialized_job_config
CCoreWorkerProcess.Initialize(options)
+2
View File
@@ -3,6 +3,7 @@ from ray.core.generated.gcs_pb2 import (
ActorTableData,
GcsNodeInfo,
JobTableData,
JobConfig,
ErrorTableData,
ErrorType,
GcsEntry,
@@ -25,6 +26,7 @@ __all__ = [
"ActorTableData",
"GcsNodeInfo",
"JobTableData",
"JobConfig",
"ErrorTableData",
"ErrorType",
"GcsEntry",
+1
View File
@@ -226,6 +226,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(c_bool() nogil) kill_main
CCoreWorkerOptions()
(void() nogil) terminate_asyncio_thread
c_string serialized_job_config
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
@staticmethod
+38
View File
@@ -0,0 +1,38 @@
import ray
class JobConfig:
    """Container for the settings that apply to a single job.

    Attributes:
        worker_env (dict): Environment variables to be set on worker
            processes.
        num_java_workers_per_process (int): The number of java workers per
            worker process.
        jvm_options (str[]): The jvm options for java workers of the job.
    """

    def __init__(
            self,
            worker_env=None,
            num_java_workers_per_process=10,
            jvm_options=None,
    ):
        # ``None`` means "not supplied": the instance then gets its own fresh
        # container. Any object the caller does supply (even an empty one) is
        # stored as-is, without copying.
        self.worker_env = {} if worker_env is None else worker_env
        self.num_java_workers_per_process = num_java_workers_per_process
        self.jvm_options = [] if jvm_options is None else jvm_options

    def serialize(self):
        """Encode this configuration as a serialized ``JobConfig`` protobuf.

        Returns:
            The result of ``SerializeToString()`` on a
            ``ray.gcs_utils.JobConfig`` message populated from this object.
        """
        message = ray.gcs_utils.JobConfig()
        for name, value in self.worker_env.items():
            message.worker_env[name] = value
        message.num_java_workers_per_process = (
            self.num_java_workers_per_process)
        message.jvm_options.extend(self.jvm_options)
        return message.SerializeToString()
+3 -1
View File
@@ -679,7 +679,9 @@ class Node:
huge_pages=self._ray_params.huge_pages,
fate_share=self.kernel_fate_share,
socket_to_use=self.socket,
head_node=self.head)
head_node=self.head,
start_initial_python_workers_for_first_job=self._ray_params.
start_initial_python_workers_for_first_job)
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
+5
View File
@@ -94,6 +94,8 @@ class RayParams:
lru_evict (bool): Enable LRU eviction if space is needed.
enable_object_reconstruction (bool): Enable plasma reconstruction on
failure.
start_initial_python_workers_for_first_job (bool): If true, start
initial Python workers for the first job on the node.
"""
def __init__(self,
@@ -136,6 +138,7 @@ class RayParams:
include_java=False,
java_worker_options=None,
load_code_from_local=False,
start_initial_python_workers_for_first_job=False,
_internal_config=None,
enable_object_reconstruction=False,
metrics_agent_port=None,
@@ -178,6 +181,8 @@ class RayParams:
self.java_worker_options = java_worker_options
self.load_code_from_local = load_code_from_local
self.metrics_agent_port = metrics_agent_port
self.start_initial_python_workers_for_first_job = (
start_initial_python_workers_for_first_job)
self._internal_config = _internal_config
self._lru_evict = lru_evict
self._enable_object_reconstruction = enable_object_reconstruction
+5 -1
View File
@@ -1271,7 +1271,8 @@ def start_raylet(redis_address,
huge_pages=False,
fate_share=None,
socket_to_use=None,
head_node=False):
head_node=False,
start_initial_python_workers_for_first_job=False):
"""Start a raylet, which is a combined local scheduler and object manager.
Args:
@@ -1403,6 +1404,9 @@ def start_raylet(redis_address,
"--session_dir={}".format(session_dir),
"--metrics-agent-port={}".format(metrics_agent_port),
]
if start_initial_python_workers_for_first_job:
command.append("--num_initial_python_workers_for_first_job={}".format(
resource_spec.num_cpus))
if config.get("plasma_store_as_thread"):
# command related to the plasma store
plasma_directory, object_store_memory = determine_plasma_store_config(
+8
View File
@@ -111,6 +111,14 @@ py_test(
deps = ["//:ray_lib"],
)
# Test target for test_multi_tenancy.py, built from the shared SRCS list plus
# the test file itself.
py_test(
    name = "test_multi_tenancy",
    size = "small",
    srcs = SRCS + ["test_multi_tenancy.py"],
    tags = ["exclusive"],
    deps = ["//:ray_lib"],
)
py_test(
name = "test_component_failures",
size = "small",
+116
View File
@@ -0,0 +1,116 @@
# coding: utf-8
import json
import sys
import grpc
import pytest
import ray
import ray.test_utils
from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
from ray.test_utils import wait_for_condition, run_string_as_driver_nonblocking
# Test that when `redis_address` and `job_config` are not set in
# `ray.init(...)`, Raylet will start `num_cpus` Python workers for the driver.
def test_initial_workers(shutdown_only):
    """Check that one non-driver worker is reported after `ray.init`.

    Starts Ray with a single CPU and multi-tenancy enabled, then queries the
    raylet's `GetNodeStats` RPC through `wait_for_condition` (10 second
    timeout) until exactly one non-driver worker is listed.
    """
    # `num_cpus` should be <=2 because a Travis CI machine only has 2 CPU cores
    ray.init(
        num_cpus=1,
        include_dashboard=True,
        _internal_config=json.dumps({
            "enable_multi_tenancy": True
        }))
    raylet = ray.nodes()[0]
    raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
                                    raylet["NodeManagerPort"])

    # Use the channel as a context manager so the connection is closed when
    # the check is done instead of being leaked.
    with grpc.insecure_channel(raylet_address) as channel:
        stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)

        def count_non_driver_workers():
            stats = stub.GetNodeStats(node_manager_pb2.GetNodeStatsRequest())
            return sum(
                1 for worker in stats.workers_stats if not worker.is_driver)

        # NOTE(review): the result of wait_for_condition is not checked here;
        # confirm that it raises on timeout, otherwise this test cannot fail.
        wait_for_condition(
            lambda: count_non_driver_workers() == 1, timeout=10)
# This test case starts some driver processes. Each driver process submits
# some tasks and collects the PIDs of the workers used by the driver. The
# drivers output the PID list which will be read by the test case itself. The
# test case will compare the PIDs used by different drivers and make sure that
# all the PIDs don't overlap. If overlapped, it means that tasks owned by
# different drivers were scheduled to the same worker process, that is, tasks
# of different jobs were not correctly isolated during execution.
def test_multi_drivers(shutdown_only):
    """Check that no worker PID is reported by more than one driver.

    Launches 10 driver scripts against the cluster started here. Each script
    prints a line of the form ``PID:<pid>,<pid>,...``; those lines are parsed
    and every PID must be unique across all drivers.
    """
    info = ray.init(
        _internal_config=json.dumps({
            "enable_multi_tenancy": True
        }))
    driver_code = """
import os
import sys
import ray
ray.init(address="{}")
@ray.remote
class Actor:
    def get_pid(self):
        return os.getpid()
@ray.remote
def get_pid():
    return os.getpid()
pid_objs = []
# Submit some normal tasks and get the PIDs of workers which execute the tasks.
pid_objs = pid_objs + [get_pid.remote() for _ in range(5)]
# Create some actors and get the PIDs of actors.
actors = [Actor.remote() for _ in range(5)]
pid_objs = pid_objs + [actor.get_pid.remote() for actor in actors]
pids = set([ray.get(obj) for obj in pid_objs])
# Write pids to stdout
print("PID:" + str.join(",", [str(_) for _ in pids]))
ray.shutdown()
""".format(info["redis_address"])
    driver_count = 10
    drivers = [
        run_string_as_driver_nonblocking(driver_code)
        for _ in range(driver_count)
    ]

    # Collect each driver's output, then wait for it to exit.
    # NOTE(review): stdout is drained completely before stderr is read; a
    # driver that writes a lot to stderr could block here -- confirm whether
    # reading both streams together is possible with this helper.
    results = []
    for driver in drivers:
        stdout_text = driver.stdout.read().decode("ascii")
        stderr_text = driver.stderr.read().decode("ascii")
        driver.wait()
        if driver.returncode != 0:
            # Dump the failing driver's output to help debugging; the actual
            # assertion on the return code happens in the loop below.
            print("Driver with PID {} returned error code {}".format(
                driver.pid, driver.returncode))
            print("STDOUT:\n{}".format(stdout_text))
            print("STDERR:\n{}".format(stderr_text))
        results.append((driver, stdout_text))

    # Every PID may be claimed by at most one driver.
    seen_pids = set()
    for driver, stdout_text in results:
        assert driver.returncode == 0
        pid_lines = [
            text for text in stdout_text.splitlines()
            if text.startswith("PID:")
        ]
        for pid_line in pid_lines:
            pid_fields = pid_line.split(":")[1].split(",")
            driver_pids = list(map(int, pid_fields))
            assert len(driver_pids) > 0
            for pid in driver_pids:
                assert pid not in seen_pids, (
                    ("Worker process with PID {} is shared" +
                     " by multiple drivers.").format(pid))
                seen_pids.add(pid)
if __name__ == "__main__":
    # When executed directly, run this file through pytest in verbose mode
    # and use pytest's return value as the process exit status.
    sys.exit(pytest.main(["-v", __file__]))
+13 -3
View File
@@ -20,6 +20,7 @@ import ray.cloudpickle as pickle
import ray.gcs_utils
import ray.memory_monitor as memory_monitor
import ray.node
import ray.job_config
import ray.parameter
import ray.ray_constants as ray_constants
import ray.remote_function
@@ -486,6 +487,7 @@ def init(address=None,
dashboard_host="localhost",
dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT,
job_id=None,
job_config=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
@@ -588,6 +590,7 @@ def init(address=None,
dashboard_port: The port to bind the dashboard server to. Defaults to
8265.
job_id: The ID of this job.
job_config (ray.job_config.JobConfig): The job configuration.
configure_logging: True (default) if configuration of logging is
allowed here. Otherwise, the user may want to configure it
separately.
@@ -713,6 +716,7 @@ def init(address=None,
temp_dir=temp_dir,
load_code_from_local=load_code_from_local,
java_worker_options=java_worker_options,
start_initial_python_workers_for_first_job=True,
_internal_config=_internal_config,
lru_evict=lru_evict,
enable_object_reconstruction=enable_object_reconstruction)
@@ -803,7 +807,8 @@ def init(address=None,
log_to_driver=log_to_driver,
worker=global_worker,
driver_object_store_memory=driver_object_store_memory,
job_id=job_id)
job_id=job_id,
job_config=job_config)
for hook in _post_init_hooks:
hook()
@@ -1147,7 +1152,8 @@ def connect(node,
log_to_driver=False,
worker=global_worker,
driver_object_store_memory=None,
job_id=None):
job_id=None,
job_config=None):
"""Connect this worker to the raylet, to Plasma, and to Redis.
Args:
@@ -1160,6 +1166,7 @@ def connect(node,
driver_object_store_memory: Limit the amount of memory the driver can
use in the object store when creating objects.
job_id: The ID of job. If it's None, then we will generate one.
job_config (ray.job_config.JobConfig): The job configuration.
"""
# Do some basic checking to make sure we didn't call ray.init twice.
error_message = "Perhaps you called ray.init twice by accident?"
@@ -1265,7 +1272,9 @@ def connect(node,
int(redis_port),
node.redis_password,
)
if job_config is None:
job_config = ray.job_config.JobConfig()
serialized_job_config = job_config.serialize()
worker.core_worker = ray._raylet.CoreWorker(
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
node.plasma_store_socket_name,
@@ -1280,6 +1289,7 @@ def connect(node,
driver_name,
log_stdout_file_path,
log_stderr_file_path,
serialized_job_config,
)
# Create an object for interfacing with the global state.