mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:54:34 +08:00
[Core] Multi-tenancy: Job isolation & implement per job config (except for env variables) (#9500)
This commit is contained in:
@@ -658,7 +658,8 @@ cdef class CoreWorker:
|
||||
def __cinit__(self, is_driver, store_socket, raylet_socket,
|
||||
JobID job_id, GcsClientOptions gcs_options, log_dir,
|
||||
node_ip_address, node_manager_port, raylet_ip_address,
|
||||
local_mode, driver_name, stdout_file, stderr_file):
|
||||
local_mode, driver_name, stdout_file, stderr_file,
|
||||
serialized_job_config):
|
||||
self.is_driver = is_driver
|
||||
self.is_local_mode = local_mode
|
||||
|
||||
@@ -688,6 +689,7 @@ cdef class CoreWorker:
|
||||
options.num_workers = 1
|
||||
options.kill_main = kill_main_task
|
||||
options.terminate_asyncio_thread = terminate_asyncio_thread
|
||||
options.serialized_job_config = serialized_job_config
|
||||
|
||||
CCoreWorkerProcess.Initialize(options)
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from ray.core.generated.gcs_pb2 import (
|
||||
ActorTableData,
|
||||
GcsNodeInfo,
|
||||
JobTableData,
|
||||
JobConfig,
|
||||
ErrorTableData,
|
||||
ErrorType,
|
||||
GcsEntry,
|
||||
@@ -25,6 +26,7 @@ __all__ = [
|
||||
"ActorTableData",
|
||||
"GcsNodeInfo",
|
||||
"JobTableData",
|
||||
"JobConfig",
|
||||
"ErrorTableData",
|
||||
"ErrorType",
|
||||
"GcsEntry",
|
||||
|
||||
@@ -226,6 +226,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
(c_bool() nogil) kill_main
|
||||
CCoreWorkerOptions()
|
||||
(void() nogil) terminate_asyncio_thread
|
||||
c_string serialized_job_config
|
||||
|
||||
cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess":
|
||||
@staticmethod
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
import ray
|
||||
|
||||
|
||||
class JobConfig:
|
||||
"""A class used to store the configurations of a job.
|
||||
|
||||
Attributes:
|
||||
worker_env (dict): Environment variables to be set on worker
|
||||
processes.
|
||||
num_java_workers_per_process (int): The number of java workers per
|
||||
worker process.
|
||||
jvm_options (str[]): The jvm options for java workers of the job.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_env=None,
|
||||
num_java_workers_per_process=10,
|
||||
jvm_options=None,
|
||||
):
|
||||
if worker_env is None:
|
||||
self.worker_env = dict()
|
||||
else:
|
||||
self.worker_env = worker_env
|
||||
self.num_java_workers_per_process = num_java_workers_per_process
|
||||
if jvm_options is None:
|
||||
self.jvm_options = []
|
||||
else:
|
||||
self.jvm_options = jvm_options
|
||||
|
||||
def serialize(self):
|
||||
job_config = ray.gcs_utils.JobConfig()
|
||||
for key in self.worker_env:
|
||||
job_config.worker_env[key] = self.worker_env[key]
|
||||
job_config.num_java_workers_per_process = (
|
||||
self.num_java_workers_per_process)
|
||||
job_config.jvm_options.extend(self.jvm_options)
|
||||
return job_config.SerializeToString()
|
||||
+3
-1
@@ -679,7 +679,9 @@ class Node:
|
||||
huge_pages=self._ray_params.huge_pages,
|
||||
fate_share=self.kernel_fate_share,
|
||||
socket_to_use=self.socket,
|
||||
head_node=self.head)
|
||||
head_node=self.head,
|
||||
start_initial_python_workers_for_first_job=self._ray_params.
|
||||
start_initial_python_workers_for_first_job)
|
||||
assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes
|
||||
self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info]
|
||||
|
||||
|
||||
@@ -94,6 +94,8 @@ class RayParams:
|
||||
lru_evict (bool): Enable LRU eviction if space is needed.
|
||||
enable_object_reconstruction (bool): Enable plasma reconstruction on
|
||||
failure.
|
||||
start_initial_python_workers_for_first_job (bool): If true, start
|
||||
initial Python workers for the first job on the node.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -136,6 +138,7 @@ class RayParams:
|
||||
include_java=False,
|
||||
java_worker_options=None,
|
||||
load_code_from_local=False,
|
||||
start_initial_python_workers_for_first_job=False,
|
||||
_internal_config=None,
|
||||
enable_object_reconstruction=False,
|
||||
metrics_agent_port=None,
|
||||
@@ -178,6 +181,8 @@ class RayParams:
|
||||
self.java_worker_options = java_worker_options
|
||||
self.load_code_from_local = load_code_from_local
|
||||
self.metrics_agent_port = metrics_agent_port
|
||||
self.start_initial_python_workers_for_first_job = (
|
||||
start_initial_python_workers_for_first_job)
|
||||
self._internal_config = _internal_config
|
||||
self._lru_evict = lru_evict
|
||||
self._enable_object_reconstruction = enable_object_reconstruction
|
||||
|
||||
@@ -1271,7 +1271,8 @@ def start_raylet(redis_address,
|
||||
huge_pages=False,
|
||||
fate_share=None,
|
||||
socket_to_use=None,
|
||||
head_node=False):
|
||||
head_node=False,
|
||||
start_initial_python_workers_for_first_job=False):
|
||||
"""Start a raylet, which is a combined local scheduler and object manager.
|
||||
|
||||
Args:
|
||||
@@ -1403,6 +1404,9 @@ def start_raylet(redis_address,
|
||||
"--session_dir={}".format(session_dir),
|
||||
"--metrics-agent-port={}".format(metrics_agent_port),
|
||||
]
|
||||
if start_initial_python_workers_for_first_job:
|
||||
command.append("--num_initial_python_workers_for_first_job={}".format(
|
||||
resource_spec.num_cpus))
|
||||
if config.get("plasma_store_as_thread"):
|
||||
# command related to the plasma store
|
||||
plasma_directory, object_store_memory = determine_plasma_store_config(
|
||||
|
||||
@@ -111,6 +111,14 @@ py_test(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_multi_tenancy",
|
||||
size = "small",
|
||||
srcs = SRCS + ["test_multi_tenancy.py"],
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_component_failures",
|
||||
size = "small",
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
# coding: utf-8
|
||||
import json
|
||||
import sys
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.test_utils
|
||||
from ray.core.generated import node_manager_pb2, node_manager_pb2_grpc
|
||||
from ray.test_utils import wait_for_condition, run_string_as_driver_nonblocking
|
||||
|
||||
|
||||
# Test that when `redis_address` and `job_config` is not set in
|
||||
# `ray.init(...)`, Raylet will start `num_cpus` Python workers for the driver.
|
||||
def test_initial_workers(shutdown_only):
|
||||
# `num_cpus` should be <=2 because a Travis CI machine only has 2 CPU cores
|
||||
ray.init(
|
||||
num_cpus=1,
|
||||
include_dashboard=True,
|
||||
_internal_config=json.dumps({
|
||||
"enable_multi_tenancy": True
|
||||
}))
|
||||
raylet = ray.nodes()[0]
|
||||
raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
|
||||
raylet["NodeManagerPort"])
|
||||
channel = grpc.insecure_channel(raylet_address)
|
||||
stub = node_manager_pb2_grpc.NodeManagerServiceStub(channel)
|
||||
wait_for_condition(lambda: len([
|
||||
worker for worker in stub.GetNodeStats(
|
||||
node_manager_pb2.GetNodeStatsRequest()).workers_stats
|
||||
if not worker.is_driver
|
||||
]) == 1,
|
||||
timeout=10)
|
||||
|
||||
|
||||
# This test case starts some driver processes. Each driver process submits
|
||||
# some tasks and collect the PIDs of the workers used by the driver. The
|
||||
# drivers output the PID list which will be read by the test case itself. The
|
||||
# test case will compare the PIDs used by different drivers and make sure that
|
||||
# all the PIDs don't overlap. If overlapped, it means that tasks owned by
|
||||
# different drivers were scheduled to the same worker process, that is, tasks
|
||||
# of different jobs were not correctly isolated during execution.
|
||||
def test_multi_drivers(shutdown_only):
|
||||
info = ray.init(
|
||||
_internal_config=json.dumps({
|
||||
"enable_multi_tenancy": True
|
||||
}))
|
||||
|
||||
driver_code = """
|
||||
import os
|
||||
import sys
|
||||
import ray
|
||||
|
||||
|
||||
ray.init(address="{}")
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def get_pid(self):
|
||||
return os.getpid()
|
||||
|
||||
@ray.remote
|
||||
def get_pid():
|
||||
return os.getpid()
|
||||
|
||||
pid_objs = []
|
||||
# Submit some normal tasks and get the PIDs of workers which execute the tasks.
|
||||
pid_objs = pid_objs + [get_pid.remote() for _ in range(5)]
|
||||
# Create some actors and get the PIDs of actors.
|
||||
actors = [Actor.remote() for _ in range(5)]
|
||||
pid_objs = pid_objs + [actor.get_pid.remote() for actor in actors]
|
||||
|
||||
pids = set([ray.get(obj) for obj in pid_objs])
|
||||
# Write pids to stdout
|
||||
print("PID:" + str.join(",", [str(_) for _ in pids]))
|
||||
|
||||
ray.shutdown()
|
||||
""".format(info["redis_address"])
|
||||
|
||||
driver_count = 10
|
||||
processes = [
|
||||
run_string_as_driver_nonblocking(driver_code)
|
||||
for _ in range(driver_count)
|
||||
]
|
||||
outputs = []
|
||||
for p in processes:
|
||||
out = p.stdout.read().decode("ascii")
|
||||
err = p.stderr.read().decode("ascii")
|
||||
p.wait()
|
||||
# out, err = p.communicate()
|
||||
# out = ray.utils.decode(out)
|
||||
# err = ray.utils.decode(err)
|
||||
if p.returncode != 0:
|
||||
print("Driver with PID {} returned error code {}".format(
|
||||
p.pid, p.returncode))
|
||||
print("STDOUT:\n{}".format(out))
|
||||
print("STDERR:\n{}".format(err))
|
||||
outputs.append((p, out))
|
||||
|
||||
all_worker_pids = set()
|
||||
for p, out in outputs:
|
||||
assert p.returncode == 0
|
||||
for line in out.splitlines():
|
||||
if line.startswith("PID:"):
|
||||
worker_pids = [int(_) for _ in line.split(":")[1].split(",")]
|
||||
assert len(worker_pids) > 0
|
||||
for worker_pid in worker_pids:
|
||||
assert worker_pid not in all_worker_pids, (
|
||||
("Worker process with PID {} is shared" +
|
||||
" by multiple drivers.").format(worker_pid))
|
||||
all_worker_pids.add(worker_pid)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
+13
-3
@@ -20,6 +20,7 @@ import ray.cloudpickle as pickle
|
||||
import ray.gcs_utils
|
||||
import ray.memory_monitor as memory_monitor
|
||||
import ray.node
|
||||
import ray.job_config
|
||||
import ray.parameter
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.remote_function
|
||||
@@ -486,6 +487,7 @@ def init(address=None,
|
||||
dashboard_host="localhost",
|
||||
dashboard_port=ray_constants.DEFAULT_DASHBOARD_PORT,
|
||||
job_id=None,
|
||||
job_config=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
@@ -588,6 +590,7 @@ def init(address=None,
|
||||
dashboard_port: The port to bind the dashboard server to. Defaults to
|
||||
8265.
|
||||
job_id: The ID of this job.
|
||||
job_config (ray.job_config.JobConfig): The job configuration.
|
||||
configure_logging: True (default) if configuration of logging is
|
||||
allowed here. Otherwise, the user may want to configure it
|
||||
separately.
|
||||
@@ -713,6 +716,7 @@ def init(address=None,
|
||||
temp_dir=temp_dir,
|
||||
load_code_from_local=load_code_from_local,
|
||||
java_worker_options=java_worker_options,
|
||||
start_initial_python_workers_for_first_job=True,
|
||||
_internal_config=_internal_config,
|
||||
lru_evict=lru_evict,
|
||||
enable_object_reconstruction=enable_object_reconstruction)
|
||||
@@ -803,7 +807,8 @@ def init(address=None,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_object_store_memory=driver_object_store_memory,
|
||||
job_id=job_id)
|
||||
job_id=job_id,
|
||||
job_config=job_config)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
@@ -1147,7 +1152,8 @@ def connect(node,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
driver_object_store_memory=None,
|
||||
job_id=None):
|
||||
job_id=None,
|
||||
job_config=None):
|
||||
"""Connect this worker to the raylet, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
@@ -1160,6 +1166,7 @@ def connect(node,
|
||||
driver_object_store_memory: Limit the amount of memory the driver can
|
||||
use in the object store when creating objects.
|
||||
job_id: The ID of job. If it's None, then we will generate one.
|
||||
job_config (ray.job_config.JobConfig): The job configuration.
|
||||
"""
|
||||
# Do some basic checking to make sure we didn't call ray.init twice.
|
||||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
@@ -1265,7 +1272,9 @@ def connect(node,
|
||||
int(redis_port),
|
||||
node.redis_password,
|
||||
)
|
||||
|
||||
if job_config is None:
|
||||
job_config = ray.job_config.JobConfig()
|
||||
serialized_job_config = job_config.serialize()
|
||||
worker.core_worker = ray._raylet.CoreWorker(
|
||||
(mode == SCRIPT_MODE or mode == LOCAL_MODE),
|
||||
node.plasma_store_socket_name,
|
||||
@@ -1280,6 +1289,7 @@ def connect(node,
|
||||
driver_name,
|
||||
log_stdout_file_path,
|
||||
log_stderr_file_path,
|
||||
serialized_job_config,
|
||||
)
|
||||
|
||||
# Create an object for interfacing with the global state.
|
||||
|
||||
Reference in New Issue
Block a user