Add password authentication to Redis ports (#2952)

* Implement Redis authentication

* Throw exception for legacy Ray

* Add test

* Formatting

* Fix bugs in CLI

* Fix bugs in Raylet

* Move default password to constants.h

* Use pytest.fixture

* Fix bug

* Authenticate using formatted strings

* Add missing passwords

* Add test

* Improve authentication of async contexts

* Disable Redis authentication for credis

* Update test for credis

* Fix rebase artifacts

* Fix formatting

* Add workaround for issue #3045

* Increase timeout for test

* Improve C++ readability

* Fixes for CLI

* Add security docs

* Address comments

* Address comments

* Adress comments

* Use ray.get

* Fix lint
This commit is contained in:
Peter Schafhalter
2018-10-16 22:48:30 -07:00
committed by Philipp Moritz
parent a9e454f6fd
commit a41bbc10ef
22 changed files with 462 additions and 115 deletions
+7 -2
View File
@@ -78,6 +78,7 @@ class GlobalState(object):
def _initialize_global_state(self,
redis_ip_address,
redis_port,
redis_password=None,
timeout=20):
"""Initialize the GlobalState object by connecting to Redis.
@@ -89,9 +90,10 @@ class GlobalState(object):
redis_ip_address: The IP address of the node that the Redis server
lives on.
redis_port: The port that the Redis server is listening on.
redis_password: The password of the redis server.
"""
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=redis_password)
start_time = time.time()
@@ -143,7 +145,10 @@ class GlobalState(object):
for ip_address_port in ip_address_ports:
shard_address, shard_port = ip_address_port.split(b":")
self.redis_clients.append(
redis.StrictRedis(host=shard_address, port=shard_port))
redis.StrictRedis(
host=shard_address,
port=shard_port,
password=redis_password))
def _execute_command(self, key, *args):
"""Execute a Redis command on the appropriate Redis shard based on key.
+17 -4
View File
@@ -35,11 +35,15 @@ class LogMonitor(object):
handle for that file.
"""
def __init__(self, redis_ip_address, redis_port, node_ip_address):
def __init__(self,
redis_ip_address,
redis_port,
node_ip_address,
redis_password=None):
"""Initialize the log monitor object."""
self.node_ip_address = node_ip_address
self.redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=redis_password)
self.log_files = {}
self.log_file_handles = {}
self.files_to_ignore = set()
@@ -130,6 +134,12 @@ if __name__ == "__main__":
required=True,
type=str,
help="The IP address of the node this process is on.")
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="the password to use for Redis")
parser.add_argument(
"--logging-level",
required=False,
@@ -151,6 +161,9 @@ if __name__ == "__main__":
redis_ip_address = get_ip_address(args.redis_address)
redis_port = get_port(args.redis_address)
log_monitor = LogMonitor(redis_ip_address, redis_port,
args.node_ip_address)
log_monitor = LogMonitor(
redis_ip_address,
redis_port,
args.node_ip_address,
redis_password=args.redis_password)
log_monitor.run()
+22 -5
View File
@@ -70,13 +70,18 @@ class Monitor(object):
managers that were up at one point and have died since then.
"""
def __init__(self, redis_address, redis_port, autoscaling_config):
def __init__(self,
redis_address,
redis_port,
autoscaling_config,
redis_password=None):
# Initialize the Redis clients.
self.state = ray.experimental.state.GlobalState()
self.state._initialize_global_state(redis_address, redis_port)
self.state._initialize_global_state(
redis_address, redis_port, redis_password=redis_password)
self.use_raylet = self.state.use_raylet
self.redis = redis.StrictRedis(
host=redis_address, port=redis_port, db=0)
host=redis_address, port=redis_port, db=0, password=redis_password)
# Setup subscriptions to the primary Redis server and the Redis shards.
self.primary_subscribe_client = self.redis.pubsub(
ignore_subscribe_messages=True)
@@ -118,7 +123,9 @@ class Monitor(object):
else:
addr_port = addr_port[0].split(b":")
self.redis_shard = redis.StrictRedis(
host=addr_port[0], port=addr_port[1])
host=addr_port[0],
port=addr_port[1],
password=redis_password)
try:
self.redis_shard.execute_command("HEAD.FLUSH 0")
except redis.exceptions.ResponseError as e:
@@ -773,6 +780,12 @@ if __name__ == "__main__":
required=False,
type=str,
help="the path to the autoscaling config file")
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="the password to use for Redis")
parser.add_argument(
"--logging-level",
required=False,
@@ -798,7 +811,11 @@ if __name__ == "__main__":
else:
autoscaling_config = None
monitor = Monitor(redis_ip_address, redis_port, autoscaling_config)
monitor = Monitor(
redis_ip_address,
redis_port,
autoscaling_config,
redis_password=args.redis_password)
try:
monitor.run()
+29 -10
View File
@@ -89,6 +89,11 @@ def cli(logging_level, logging_format):
type=int,
help=("If provided, attempt to configure Redis with this "
"maximum number of clients."))
@click.option(
"--redis-password",
required=False,
type=str,
help="If provided, secure Redis ports with this password")
@click.option(
"--redis-shard-ports",
required=False,
@@ -190,10 +195,11 @@ def cli(logging_level, logging_format):
default=None,
help="manually specify the root temporary dir of the Ray process")
def start(node_ip_address, redis_address, redis_port, num_redis_shards,
redis_max_clients, redis_shard_ports, object_manager_port,
object_store_memory, num_workers, num_cpus, num_gpus, resources,
head, no_ui, block, plasma_directory, huge_pages, autoscaling_config,
use_raylet, no_redirect_worker_output, no_redirect_output,
redis_max_clients, redis_password, redis_shard_ports,
object_manager_port, object_store_memory, num_workers, num_cpus,
num_gpus, resources, head, no_ui, block, plasma_directory,
huge_pages, autoscaling_config, use_raylet,
no_redirect_worker_output, no_redirect_output,
plasma_store_socket_name, raylet_socket_name, temp_dir):
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
@@ -205,6 +211,11 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
# This environment variable is used in our testing setup.
logger.info("Detected environment variable 'RAY_USE_XRAY'.")
use_raylet = True
if not use_raylet and redis_password is not None:
raise Exception("Setting the 'redis-password' argument is not "
"supported in legacy Ray. To run Ray with "
"password-protected Redis ports, pass "
"the '--use-raylet' flag.")
try:
resources = json.loads(resources)
@@ -269,6 +280,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
redis_protected_mode=False,
redis_password=redis_password,
include_webui=(not no_ui),
plasma_directory=plasma_directory,
huge_pages=huge_pages,
@@ -281,16 +293,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
logger.info(
"\nStarted Ray on this node. You can add additional nodes to "
"the cluster by calling\n\n"
" ray start --redis-address {}\n\n"
" ray start --redis-address {}{}{}\n\n"
"from the node you wish to add. You can connect a driver to the "
"cluster from Python by running\n\n"
" import ray\n"
" ray.init(redis_address=\"{}\")\n\n"
" ray.init(redis_address=\"{}{}{}\")\n\n"
"If you have trouble connecting from a different machine, check "
"that your firewall is configured properly. If you wish to "
"terminate the processes that have been started, run\n\n"
" ray stop".format(address_info["redis_address"],
address_info["redis_address"]))
" ray stop".format(
address_info["redis_address"], " --redis-password "
if redis_password else "", redis_password if redis_password
else "", address_info["redis_address"], "\", redis_password=\""
if redis_password else "", redis_password
if redis_password else ""))
else:
# Start Ray on a non-head node.
if redis_port is not None:
@@ -315,10 +331,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
# Wait for the Redis server to be started. And throw an exception if we
# can't connect to it.
services.wait_for_redis_to_start(redis_ip_address, int(redis_port))
services.wait_for_redis_to_start(
redis_ip_address, int(redis_port), password=redis_password)
# Create a Redis client.
redis_client = services.create_redis_client(redis_address)
redis_client = services.create_redis_client(
redis_address, password=redis_password)
# Check that the verion information on this node matches the version
# information that the cluster was started with.
@@ -339,6 +357,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
object_manager_ports=[object_manager_port],
num_workers=num_workers,
object_store_memory=object_store_memory,
redis_password=redis_password,
cleanup=False,
redirect_worker_output=not no_redirect_worker_output,
redirect_output=not no_redirect_output,
+136 -52
View File
@@ -261,7 +261,10 @@ def get_node_ip_address(address="8.8.8.8:53"):
return node_ip_address
def record_log_files_in_redis(redis_address, node_ip_address, log_files):
def record_log_files_in_redis(redis_address,
node_ip_address,
log_files,
password=None):
"""Record in Redis that a new log file has been created.
This is used so that each log monitor can check Redis and figure out which
@@ -273,23 +276,24 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
on.
log_files: A list of file handles for the log files. If one of the file
handles is None, we ignore it.
password (str): The password of the redis server.
"""
for log_file in log_files:
if log_file is not None:
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=password)
# The name of the key storing the list of log filenames for this IP
# address.
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
redis_client.rpush(log_file_list_key, log_file.name)
def create_redis_client(redis_address):
def create_redis_client(redis_address, password=None):
"""Create a Redis client.
Args:
The IP address and port of the Redis server.
The IP address, port, and password of the Redis server.
Returns:
A Redis client.
@@ -297,10 +301,14 @@ def create_redis_client(redis_address):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine
# as Redis) must have run "CONFIG SET protected-mode no".
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
return redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=password)
def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
def wait_for_redis_to_start(redis_ip_address,
redis_port,
password=None,
num_retries=5):
"""Wait for a Redis server to be available.
This is accomplished by creating a Redis client and sending a random
@@ -309,13 +317,15 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
Args:
redis_ip_address (str): The IP address of the redis server.
redis_port (int): The port of the redis server.
password (str): The password of the redis server.
num_retries (int): The number of times to try connecting with redis.
The client will sleep for one second between attempts.
Raises:
Exception: An exception is raised if we could not connect with Redis.
"""
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port, password=password)
# Wait for the Redis server to start.
counter = 0
while counter < num_retries:
@@ -425,6 +435,7 @@ def start_redis(node_ip_address,
redirect_worker_output=False,
cleanup=True,
protected_mode=False,
password=None,
use_credis=None):
"""Start the Redis global state store.
@@ -451,6 +462,8 @@ def start_redis(node_ip_address,
then all Redis processes started by this method will be killed by
services.cleanup() when the Python process that imported services
exits.
password (str): Prevents external clients without the password
from connecting to Redis if provided.
use_credis: If True, additionally load the chain-replicated libraries
into the redis servers. Defaults to None, which means its value is
set by the presence of "RAY_USE_NEW_GCS" in os.environ.
@@ -469,6 +482,13 @@ def start_redis(node_ip_address,
if use_credis is None:
use_credis = ("RAY_USE_NEW_GCS" in os.environ)
if use_credis and password is not None:
# TODO(pschafhalter) remove this once credis supports
# authenticating Redis ports
raise Exception("Setting the `redis_password` argument is not "
"supported in credis. To run Ray with "
"password-protected Redis ports, ensure that "
"the environment variable `RAY_USE_NEW_GCS=off`.")
if not use_credis:
assigned_port, _ = _start_redis_instance(
node_ip_address=node_ip_address,
@@ -477,7 +497,8 @@ def start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode)
protected_mode=protected_mode,
password=password)
else:
assigned_port, _ = _start_redis_instance(
node_ip_address=node_ip_address,
@@ -491,20 +512,23 @@ def start_redis(node_ip_address,
# It is important to load the credis module BEFORE the ray module,
# as the latter contains an extern declaration that the former
# supplies.
modules=[CREDIS_MASTER_MODULE, REDIS_MODULE])
modules=[CREDIS_MASTER_MODULE, REDIS_MODULE],
password=password)
if port is not None:
assert assigned_port == port
port = assigned_port
redis_address = address(node_ip_address, port)
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
redis_client = redis.StrictRedis(
host=node_ip_address, port=port, password=password)
# Store whether we're using the raylet code path or not.
redis_client.set("UseRaylet", 1 if use_raylet else 0)
# Register the number of Redis shards in the primary shard, so that clients
# know how many redis shards to expect under RedisShards.
primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port)
primary_redis_client = redis.StrictRedis(
host=node_ip_address, port=port, password=password)
primary_redis_client.set("NumRedisShards", str(num_redis_shards))
# Put the redirect_worker_output bool in the Redis shard so that workers
@@ -529,7 +553,8 @@ def start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode)
protected_mode=protected_mode,
password=password)
else:
assert num_redis_shards == 1, \
"For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\
@@ -542,6 +567,7 @@ def start_redis(node_ip_address,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode,
password=password,
executable=CREDIS_EXECUTABLE,
# It is important to load the credis module BEFORE the ray
# module, as the latter contains an extern declaration that the
@@ -557,7 +583,7 @@ def start_redis(node_ip_address,
if use_credis:
shard_client = redis.StrictRedis(
host=node_ip_address, port=redis_shard_port)
host=node_ip_address, port=redis_shard_port, password=password)
# Configure the chain state.
primary_redis_client.execute_command("MASTER.ADD", node_ip_address,
redis_shard_port)
@@ -591,6 +617,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
stderr_file=None,
cleanup=True,
protected_mode=False,
password=None,
executable=REDIS_EXECUTABLE,
modules=None):
"""Start a single Redis server.
@@ -614,6 +641,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
mode. This will prevent clients on other machines from connecting
and is only used when the Redis servers are started via ray.init()
as opposed to ray start.
password (str): Prevents external clients without the password
from connecting to Redis if provided.
executable (str): Full path tho the redis-server executable.
modules (list of str): A list of pathnames, pointing to the redis
module(s) that will be loaded in this redis server. If None, load
@@ -654,6 +683,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
command = [executable]
if protected_mode:
command += [redis_config_filename]
if password:
command += ["--requirepass", password]
command += (
["--port", str(port), "--loglevel", "warning"] + load_module_args)
@@ -672,9 +703,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
stdout_file.name, stderr_file.name))
# Create a Redis client just for configuring Redis.
redis_client = redis.StrictRedis(host="127.0.0.1", port=port)
redis_client = redis.StrictRedis(
host="127.0.0.1", port=port, password=password)
# Wait for the Redis server to start.
wait_for_redis_to_start("127.0.0.1", port)
wait_for_redis_to_start("127.0.0.1", port, password=password)
# Configure Redis to generate keyspace notifications. TODO(rkn): Change
# this to only generate notifications for the export keys.
redis_client.config_set("notify-keyspace-events", "Kl")
@@ -719,8 +751,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
redis_client.set("redis_start_time", time.time())
# Record the log files in Redis.
record_log_files_in_redis(
address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
address(node_ip_address, port),
node_ip_address, [stdout_file, stderr_file],
password=password)
return port, p
@@ -728,7 +761,8 @@ def start_log_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=cleanup):
cleanup=cleanup,
redis_password=None):
"""Start a log monitor process.
Args:
@@ -742,27 +776,31 @@ def start_log_monitor(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
"""
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
p = subprocess.Popen(
[
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
],
stdout=stdout_file,
stderr=stderr_file)
command = [
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
]
if redis_password:
command += ["--redis-password", redis_password]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_global_scheduler(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Start a global scheduler process.
Args:
@@ -776,6 +814,7 @@ def start_global_scheduler(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
"""
p = global_scheduler.start_global_scheduler(
redis_address,
@@ -784,8 +823,10 @@ def start_global_scheduler(redis_address,
stderr_file=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
@@ -911,7 +952,8 @@ def start_local_scheduler(redis_address,
stderr_file=None,
cleanup=True,
resources=None,
num_workers=0):
num_workers=0,
redis_password=None):
"""Start a local scheduler process.
Args:
@@ -935,6 +977,7 @@ def start_local_scheduler(redis_address,
quantity of that resource.
num_workers (int): The number of workers that the local scheduler
should start.
redis_password (str): The password of the redis server.
Return:
The name of the local scheduler socket.
@@ -957,8 +1000,10 @@ def start_local_scheduler(redis_address,
num_workers=num_workers)
if cleanup:
all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
return local_scheduler_name
@@ -973,7 +1018,8 @@ def start_raylet(redis_address,
use_profiler=False,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Start a raylet, which is a combined local scheduler and object manager.
Args:
@@ -996,6 +1042,7 @@ def start_raylet(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by serices.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
Returns:
The raylet socket name.
@@ -1029,6 +1076,8 @@ def start_raylet(redis_address,
sys.executable, worker_path, node_ip_address,
plasma_store_name, raylet_name, redis_address,
get_temp_root()))
if redis_password:
start_worker_command += " --redis-password {}".format(redis_password)
command = [
RAYLET_EXECUTABLE,
@@ -1042,6 +1091,7 @@ def start_raylet(redis_address,
resource_argument,
start_worker_command,
"", # Worker command for Java, not needed for Python.
redis_password or "",
]
if use_valgrind:
@@ -1063,8 +1113,10 @@ def start_raylet(redis_address,
if cleanup:
all_processes[PROCESS_TYPE_RAYLET].append(pid)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
return raylet_name
@@ -1081,7 +1133,8 @@ def start_plasma_store(node_ip_address,
plasma_directory=None,
huge_pages=False,
use_raylet=False,
plasma_store_socket_name=None):
plasma_store_socket_name=None,
redis_password=None):
"""This method starts an object store process.
Args:
@@ -1111,6 +1164,7 @@ def start_plasma_store(node_ip_address,
Store with hugetlbfs support. Requires plasma_directory.
use_raylet: True if the new raylet code path should be used. This is
not supported yet.
redis_password (str): The password of the redis server.
Return:
A tuple of the Plasma store socket name, the Plasma manager socket
@@ -1186,8 +1240,10 @@ def start_plasma_store(node_ip_address,
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1)
record_log_files_in_redis(redis_address, node_ip_address,
[store_stdout_file, store_stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [store_stdout_file, store_stderr_file],
password=redis_password)
if not use_raylet:
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2)
@@ -1248,7 +1304,8 @@ def start_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True,
autoscaling_config=None):
autoscaling_config=None,
redis_password=None):
"""Run a process to monitor the other processes.
Args:
@@ -1264,6 +1321,7 @@ def start_monitor(redis_address,
Python process that imported services exits. This is True by
default.
autoscaling_config: path to autoscaling config file.
redis_password (str): The password of the redis server.
"""
monitor_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
@@ -1273,17 +1331,22 @@ def start_monitor(redis_address,
]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))
if redis_password:
command.append("--redis-password=" + redis_password)
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_raylet_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Run a process to monitor the other processes.
Args:
@@ -1296,8 +1359,10 @@ def start_raylet_monitor(redis_address,
then this process will be killed by services.cleanup() when the
Python process that imported services exits. This is True by
default.
redis_password (str): The password of the redis server.
"""
gcs_ip_address, gcs_port = redis_address.split(":")
redis_password = redis_password or ""
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
@@ -1314,6 +1379,7 @@ def start_ray_processes(address_info=None,
num_redis_shards=1,
redis_max_clients=None,
redis_protected_mode=False,
redis_password=None,
worker_path=None,
cleanup=True,
redirect_worker_output=False,
@@ -1359,6 +1425,8 @@ def start_ray_processes(address_info=None,
redis_protected_mode: True if we should start Redis in protected mode.
This will prevent clients from other machines from connecting and
is only done when Redis is started via ray.init().
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here
@@ -1444,7 +1512,8 @@ def start_ray_processes(address_info=None,
redirect_output=True,
redirect_worker_output=redirect_worker_output,
cleanup=cleanup,
protected_mode=redis_protected_mode)
protected_mode=redis_protected_mode,
password=redis_password)
address_info["redis_address"] = redis_address
time.sleep(0.1)
@@ -1457,18 +1526,20 @@ def start_ray_processes(address_info=None,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
autoscaling_config=autoscaling_config,
redis_password=redis_password)
if use_raylet:
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=redis_password)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [ray.utils.decode(shard) for shard in redis_shards]
address_info["redis_shards"] = redis_shards
@@ -1482,7 +1553,8 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
# Start the global scheduler, if necessary.
if include_global_scheduler and not use_raylet:
@@ -1493,7 +1565,8 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
# Initialize with existing services.
if "object_store_addresses" not in address_info:
@@ -1537,7 +1610,8 @@ def start_ray_processes(address_info=None,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name)
plasma_store_socket_name=plasma_store_socket_name,
redis_password=redis_password)
object_store_addresses.append(object_store_address)
time.sleep(0.1)
@@ -1575,7 +1649,8 @@ def start_ray_processes(address_info=None,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
resources=resources[i],
num_workers=num_local_scheduler_workers)
num_workers=num_local_scheduler_workers,
redis_password=redis_password)
local_scheduler_socket_names.append(local_scheduler_name)
# Make sure that we have exactly num_local_schedulers instances of
@@ -1599,7 +1674,8 @@ def start_ray_processes(address_info=None,
num_workers=workers_per_local_scheduler[i],
stdout_file=raylet_stdout_file,
stderr_file=raylet_stderr_file,
cleanup=cleanup))
cleanup=cleanup,
redis_password=redis_password))
if not use_raylet:
# Start any workers that the local scheduler has not already started.
@@ -1645,6 +1721,7 @@ def start_ray_node(node_ip_address,
num_workers=0,
num_local_schedulers=1,
object_store_memory=None,
redis_password=None,
worker_path=None,
cleanup=True,
redirect_worker_output=False,
@@ -1673,6 +1750,8 @@ def start_ray_node(node_ip_address,
start.
object_store_memory (int): The maximum amount of memory (in bytes) to
let the plasma store use.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here
@@ -1711,6 +1790,7 @@ def start_ray_node(node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
redis_password=redis_password,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
@@ -1741,6 +1821,7 @@ def start_ray_head(address_info=None,
num_redis_shards=None,
redis_max_clients=None,
redis_protected_mode=False,
redis_password=None,
include_webui=True,
plasma_directory=None,
huge_pages=False,
@@ -1792,6 +1873,8 @@ def start_ray_head(address_info=None,
redis_protected_mode: True if we should start Redis in protected mode.
This will prevent clients from other machines from connecting and
is only done when Redis is started via ray.init().
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
include_webui: True if the UI should be started and false otherwise.
plasma_directory: A directory where the Plasma memory mapped files will
be created.
@@ -1832,6 +1915,7 @@ def start_ray_head(address_info=None,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
redis_protected_mode=redis_protected_mode,
redis_password=redis_password,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
autoscaling_config=autoscaling_config,
+65
View File
@@ -0,0 +1,65 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import pytest
import redis
import ray
@pytest.fixture
def password():
random_bytes = os.urandom(128)
if hasattr(random_bytes, "hex"):
return random_bytes.hex() # Python 3
return random_bytes.encode("hex") # Python 2
@pytest.fixture
def shutdown_only():
yield None
# The code after the yield will run as teardown code.
ray.shutdown()
class TestRedisPassword(object):
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") != "on"
and os.environ.get("RAY_USE_XRAY"),
reason="Redis authentication works for raylet and old GCS.")
def test_exceptions(self, password, shutdown_only):
with pytest.raises(Exception):
ray.init(redis_password=password)
@pytest.mark.skipif(
os.environ.get("RAY_USE_NEW_GCS") == "on",
reason="New GCS API doesn't support Redis authentication yet.")
@pytest.mark.skipif(
not os.environ.get("RAY_USE_XRAY"),
reason="Redis authentication is not supported in legacy Ray.")
def test_redis_password(self, password, shutdown_only):
# Workaround for https://github.com/ray-project/ray/issues/3045
@ray.remote
def f():
return 1
info = ray.init(redis_password=password)
redis_address = info["redis_address"]
redis_ip, redis_port = redis_address.split(":")
# Check that we can run a task
object_id = f.remote()
ray.get(object_id)
# Check that Redis connections require a password
redis_client = redis.StrictRedis(
host=redis_ip, port=redis_port, password=None)
with pytest.raises(redis.ResponseError):
redis_client.ping()
# Check that we can connect to Redis using the provided password
redis_client = redis.StrictRedis(
host=redis_ip, port=redis_port, password=password)
assert redis_client.ping()
+38 -9
View File
@@ -1223,12 +1223,13 @@ def _initialize_serialization(driver_id, worker=global_worker):
def get_address_info_from_redis_helper(redis_address,
node_ip_address,
use_raylet=False):
use_raylet=False,
redis_password=None):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine as
# Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(
host=redis_ip_address, port=int(redis_port))
host=redis_ip_address, port=int(redis_port), password=redis_password)
if not use_raylet:
# The client table prefix must be kept in sync with the file
@@ -1332,12 +1333,16 @@ def get_address_info_from_redis_helper(redis_address,
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
use_raylet=False):
use_raylet=False,
redis_password=None):
counter = 0
while True:
try:
return get_address_info_from_redis_helper(
redis_address, node_ip_address, use_raylet=use_raylet)
redis_address,
node_ip_address,
use_raylet=use_raylet,
redis_password=None)
except Exception:
if counter == num_retries:
raise
@@ -1405,6 +1410,7 @@ def _init(address_info=None,
resources=None,
num_redis_shards=None,
redis_max_clients=None,
redis_password=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
@@ -1460,6 +1466,8 @@ def _init(address_info=None,
the primary Redis shard.
redis_max_clients: If provided, attempt to configure Redis with this
maxclients number.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
plasma_directory: A directory where the Plasma memory mapped files will
be created.
huge_pages: Boolean flag indicating whether to start the Object
@@ -1544,6 +1552,7 @@ def _init(address_info=None,
resources=resources,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
redis_password=redis_password,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
include_webui=include_webui,
@@ -1596,7 +1605,10 @@ def _init(address_info=None,
node_ip_address = services.get_node_ip_address(redis_address)
# Get the address info of the processes to connect to from Redis.
address_info = get_address_info_from_redis(
redis_address, node_ip_address, use_raylet=use_raylet)
redis_address,
node_ip_address,
use_raylet=use_raylet,
redis_password=redis_password)
# Connect this driver to Redis, the object store, and the local scheduler.
# Choose the first object store and local scheduler if there are multiple.
@@ -1628,7 +1640,8 @@ def _init(address_info=None,
object_id_seed=object_id_seed,
mode=driver_mode,
worker=global_worker,
use_raylet=use_raylet)
use_raylet=use_raylet,
redis_password=redis_password)
return address_info
@@ -1647,6 +1660,7 @@ def init(redis_address=None,
ignore_reinit_error=False,
num_redis_shards=None,
redis_max_clients=None,
redis_password=None,
plasma_directory=None,
huge_pages=False,
include_webui=True,
@@ -1709,6 +1723,8 @@ def init(redis_address=None,
the primary Redis shard.
redis_max_clients: If provided, attempt to configure Redis with this
maxclients number.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
plasma_directory: A directory where the Plasma memory mapped files will
be created.
huge_pages: Boolean flag indicating whether to start the Object
@@ -1750,6 +1766,11 @@ def init(redis_address=None,
# This environment variable is used in our testing setup.
logger.info("Detected environment variable 'RAY_USE_XRAY'.")
use_raylet = True
if not use_raylet and redis_password is not None:
raise Exception("Setting the 'redis_password' argument is not "
"supported in legacy Ray. To run Ray with "
"password-protected Redis ports, set "
"'use_raylet=True'.")
# Convert hostnames to numerical IP address.
if node_ip_address is not None:
@@ -1772,6 +1793,7 @@ def init(redis_address=None,
resources=resources,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
redis_password=redis_password,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
include_webui=include_webui,
@@ -1975,7 +1997,8 @@ def connect(info,
object_id_seed=None,
mode=WORKER_MODE,
worker=global_worker,
use_raylet=False):
use_raylet=False,
redis_password=None):
"""Connect this worker to the local scheduler, to Plasma, and to Redis.
Args:
@@ -1986,6 +2009,8 @@ def connect(info,
mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and
LOCAL_MODE.
use_raylet: True if the new raylet code path should be used.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
"""
# Do some basic checking to make sure we didn't call ray.init twice.
error_message = "Perhaps you called ray.init twice by accident?"
@@ -2019,7 +2044,10 @@ def connect(info,
# Create a Redis client.
redis_ip_address, redis_port = info["redis_address"].split(":")
worker.redis_client = thread_safe_client(
redis.StrictRedis(host=redis_ip_address, port=int(redis_port)))
redis.StrictRedis(
host=redis_ip_address,
port=int(redis_port),
password=redis_password))
# For driver's check that the version information matches the version
# information that the Ray cluster was started with.
@@ -2060,7 +2088,8 @@ def connect(info,
[log_stdout_file, log_stderr_file])
# Create an object for interfacing with the global state.
global_state._initialize_global_state(redis_ip_address, int(redis_port))
global_state._initialize_global_state(
redis_ip_address, int(redis_port), redis_password=redis_password)
# Register the worker with Redis.
if mode == SCRIPT_MODE:
+11 -1
View File
@@ -24,6 +24,12 @@ parser.add_argument(
required=True,
type=str,
help="the address to use for Redis")
parser.add_argument(
"--redis-password",
required=False,
type=str,
default=None,
help="the password to use for Redis")
parser.add_argument(
"--object-store-name",
required=True,
@@ -67,6 +73,7 @@ if __name__ == "__main__":
info = {
"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"redis_password": args.redis_password,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name,
@@ -81,7 +88,10 @@ if __name__ == "__main__":
tempfile_services.set_temp_root(args.temp_dir)
ray.worker.connect(
info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None))
info,
mode=ray.WORKER_MODE,
use_raylet=(args.raylet_name is not None),
redis_password=args.redis_password)
error_explanation = """
This error is unexpected and should not have happened. Somehow a worker