Add password authentication to Redis ports (#2952)

* Implement Redis authentication

* Throw exception for legacy Ray

* Add test

* Formatting

* Fix bugs in CLI

* Fix bugs in Raylet

* Move default password to constants.h

* Use pytest.fixture

* Fix bug

* Authenticate using formatted strings

* Add missing passwords

* Add test

* Improve authentication of async contexts

* Disable Redis authentication for credis

* Update test for credis

* Fix rebase artifacts

* Fix formatting

* Add workaround for issue #3045

* Increase timeout for test

* Improve C++ readability

* Fixes for CLI

* Add security docs

* Address comments

* Address comments

* Adress comments

* Use ray.get

* Fix lint
This commit is contained in:
Peter Schafhalter
2018-10-16 22:48:30 -07:00
committed by Philipp Moritz
parent a9e454f6fd
commit a41bbc10ef
22 changed files with 462 additions and 115 deletions
+136 -52
View File
@@ -261,7 +261,10 @@ def get_node_ip_address(address="8.8.8.8:53"):
return node_ip_address
def record_log_files_in_redis(redis_address, node_ip_address, log_files):
def record_log_files_in_redis(redis_address,
node_ip_address,
log_files,
password=None):
"""Record in Redis that a new log file has been created.
This is used so that each log monitor can check Redis and figure out which
@@ -273,23 +276,24 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files):
on.
log_files: A list of file handles for the log files. If one of the file
handles is None, we ignore it.
password (str): The password of the redis server.
"""
for log_file in log_files:
if log_file is not None:
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=password)
# The name of the key storing the list of log filenames for this IP
# address.
log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address)
redis_client.rpush(log_file_list_key, log_file.name)
def create_redis_client(redis_address):
def create_redis_client(redis_address, password=None):
"""Create a Redis client.
Args:
The IP address and port of the Redis server.
The IP address, port, and password of the Redis server.
Returns:
A Redis client.
@@ -297,10 +301,14 @@ def create_redis_client(redis_address):
redis_ip_address, redis_port = redis_address.split(":")
# For this command to work, some other client (on the same machine
# as Redis) must have run "CONFIG SET protected-mode no".
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
return redis.StrictRedis(
host=redis_ip_address, port=int(redis_port), password=password)
def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
def wait_for_redis_to_start(redis_ip_address,
redis_port,
password=None,
num_retries=5):
"""Wait for a Redis server to be available.
This is accomplished by creating a Redis client and sending a random
@@ -309,13 +317,15 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5):
Args:
redis_ip_address (str): The IP address of the redis server.
redis_port (int): The port of the redis server.
password (str): The password of the redis server.
num_retries (int): The number of times to try connecting with redis.
The client will sleep for one second between attempts.
Raises:
Exception: An exception is raised if we could not connect with Redis.
"""
redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port)
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port, password=password)
# Wait for the Redis server to start.
counter = 0
while counter < num_retries:
@@ -425,6 +435,7 @@ def start_redis(node_ip_address,
redirect_worker_output=False,
cleanup=True,
protected_mode=False,
password=None,
use_credis=None):
"""Start the Redis global state store.
@@ -451,6 +462,8 @@ def start_redis(node_ip_address,
then all Redis processes started by this method will be killed by
services.cleanup() when the Python process that imported services
exits.
password (str): Prevents external clients without the password
from connecting to Redis if provided.
use_credis: If True, additionally load the chain-replicated libraries
into the redis servers. Defaults to None, which means its value is
set by the presence of "RAY_USE_NEW_GCS" in os.environ.
@@ -469,6 +482,13 @@ def start_redis(node_ip_address,
if use_credis is None:
use_credis = ("RAY_USE_NEW_GCS" in os.environ)
if use_credis and password is not None:
# TODO(pschafhalter) remove this once credis supports
# authenticating Redis ports
raise Exception("Setting the `redis_password` argument is not "
"supported in credis. To run Ray with "
"password-protected Redis ports, ensure that "
"the environment variable `RAY_USE_NEW_GCS=off`.")
if not use_credis:
assigned_port, _ = _start_redis_instance(
node_ip_address=node_ip_address,
@@ -477,7 +497,8 @@ def start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode)
protected_mode=protected_mode,
password=password)
else:
assigned_port, _ = _start_redis_instance(
node_ip_address=node_ip_address,
@@ -491,20 +512,23 @@ def start_redis(node_ip_address,
# It is important to load the credis module BEFORE the ray module,
# as the latter contains an extern declaration that the former
# supplies.
modules=[CREDIS_MASTER_MODULE, REDIS_MODULE])
modules=[CREDIS_MASTER_MODULE, REDIS_MODULE],
password=password)
if port is not None:
assert assigned_port == port
port = assigned_port
redis_address = address(node_ip_address, port)
redis_client = redis.StrictRedis(host=node_ip_address, port=port)
redis_client = redis.StrictRedis(
host=node_ip_address, port=port, password=password)
# Store whether we're using the raylet code path or not.
redis_client.set("UseRaylet", 1 if use_raylet else 0)
# Register the number of Redis shards in the primary shard, so that clients
# know how many redis shards to expect under RedisShards.
primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port)
primary_redis_client = redis.StrictRedis(
host=node_ip_address, port=port, password=password)
primary_redis_client.set("NumRedisShards", str(num_redis_shards))
# Put the redirect_worker_output bool in the Redis shard so that workers
@@ -529,7 +553,8 @@ def start_redis(node_ip_address,
stdout_file=redis_stdout_file,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode)
protected_mode=protected_mode,
password=password)
else:
assert num_redis_shards == 1, \
"For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\
@@ -542,6 +567,7 @@ def start_redis(node_ip_address,
stderr_file=redis_stderr_file,
cleanup=cleanup,
protected_mode=protected_mode,
password=password,
executable=CREDIS_EXECUTABLE,
# It is important to load the credis module BEFORE the ray
# module, as the latter contains an extern declaration that the
@@ -557,7 +583,7 @@ def start_redis(node_ip_address,
if use_credis:
shard_client = redis.StrictRedis(
host=node_ip_address, port=redis_shard_port)
host=node_ip_address, port=redis_shard_port, password=password)
# Configure the chain state.
primary_redis_client.execute_command("MASTER.ADD", node_ip_address,
redis_shard_port)
@@ -591,6 +617,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
stderr_file=None,
cleanup=True,
protected_mode=False,
password=None,
executable=REDIS_EXECUTABLE,
modules=None):
"""Start a single Redis server.
@@ -614,6 +641,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
mode. This will prevent clients on other machines from connecting
and is only used when the Redis servers are started via ray.init()
as opposed to ray start.
password (str): Prevents external clients without the password
from connecting to Redis if provided.
executable (str): Full path tho the redis-server executable.
modules (list of str): A list of pathnames, pointing to the redis
module(s) that will be loaded in this redis server. If None, load
@@ -654,6 +683,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
command = [executable]
if protected_mode:
command += [redis_config_filename]
if password:
command += ["--requirepass", password]
command += (
["--port", str(port), "--loglevel", "warning"] + load_module_args)
@@ -672,9 +703,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
stdout_file.name, stderr_file.name))
# Create a Redis client just for configuring Redis.
redis_client = redis.StrictRedis(host="127.0.0.1", port=port)
redis_client = redis.StrictRedis(
host="127.0.0.1", port=port, password=password)
# Wait for the Redis server to start.
wait_for_redis_to_start("127.0.0.1", port)
wait_for_redis_to_start("127.0.0.1", port, password=password)
# Configure Redis to generate keyspace notifications. TODO(rkn): Change
# this to only generate notifications for the export keys.
redis_client.config_set("notify-keyspace-events", "Kl")
@@ -719,8 +751,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1",
redis_client.set("redis_start_time", time.time())
# Record the log files in Redis.
record_log_files_in_redis(
address(node_ip_address, port), node_ip_address,
[stdout_file, stderr_file])
address(node_ip_address, port),
node_ip_address, [stdout_file, stderr_file],
password=password)
return port, p
@@ -728,7 +761,8 @@ def start_log_monitor(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=cleanup):
cleanup=cleanup,
redis_password=None):
"""Start a log monitor process.
Args:
@@ -742,27 +776,31 @@ def start_log_monitor(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
"""
log_monitor_filepath = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "log_monitor.py")
p = subprocess.Popen(
[
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
],
stdout=stdout_file,
stderr=stderr_file)
command = [
sys.executable, "-u", log_monitor_filepath, "--redis-address",
redis_address, "--node-ip-address", node_ip_address
]
if redis_password:
command += ["--redis-password", redis_password]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_LOG_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_global_scheduler(redis_address,
node_ip_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Start a global scheduler process.
Args:
@@ -776,6 +814,7 @@ def start_global_scheduler(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by services.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
"""
p = global_scheduler.start_global_scheduler(
redis_address,
@@ -784,8 +823,10 @@ def start_global_scheduler(redis_address,
stderr_file=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True):
@@ -911,7 +952,8 @@ def start_local_scheduler(redis_address,
stderr_file=None,
cleanup=True,
resources=None,
num_workers=0):
num_workers=0,
redis_password=None):
"""Start a local scheduler process.
Args:
@@ -935,6 +977,7 @@ def start_local_scheduler(redis_address,
quantity of that resource.
num_workers (int): The number of workers that the local scheduler
should start.
redis_password (str): The password of the redis server.
Return:
The name of the local scheduler socket.
@@ -957,8 +1000,10 @@ def start_local_scheduler(redis_address,
num_workers=num_workers)
if cleanup:
all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
return local_scheduler_name
@@ -973,7 +1018,8 @@ def start_raylet(redis_address,
use_profiler=False,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Start a raylet, which is a combined local scheduler and object manager.
Args:
@@ -996,6 +1042,7 @@ def start_raylet(redis_address,
cleanup (bool): True if using Ray in local mode. If cleanup is true,
then this process will be killed by serices.cleanup() when the
Python process that imported services exits.
redis_password (str): The password of the redis server.
Returns:
The raylet socket name.
@@ -1029,6 +1076,8 @@ def start_raylet(redis_address,
sys.executable, worker_path, node_ip_address,
plasma_store_name, raylet_name, redis_address,
get_temp_root()))
if redis_password:
start_worker_command += " --redis-password {}".format(redis_password)
command = [
RAYLET_EXECUTABLE,
@@ -1042,6 +1091,7 @@ def start_raylet(redis_address,
resource_argument,
start_worker_command,
"", # Worker command for Java, not needed for Python.
redis_password or "",
]
if use_valgrind:
@@ -1063,8 +1113,10 @@ def start_raylet(redis_address,
if cleanup:
all_processes[PROCESS_TYPE_RAYLET].append(pid)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
return raylet_name
@@ -1081,7 +1133,8 @@ def start_plasma_store(node_ip_address,
plasma_directory=None,
huge_pages=False,
use_raylet=False,
plasma_store_socket_name=None):
plasma_store_socket_name=None,
redis_password=None):
"""This method starts an object store process.
Args:
@@ -1111,6 +1164,7 @@ def start_plasma_store(node_ip_address,
Store with hugetlbfs support. Requires plasma_directory.
use_raylet: True if the new raylet code path should be used. This is
not supported yet.
redis_password (str): The password of the redis server.
Return:
A tuple of the Plasma store socket name, the Plasma manager socket
@@ -1186,8 +1240,10 @@ def start_plasma_store(node_ip_address,
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1)
record_log_files_in_redis(redis_address, node_ip_address,
[store_stdout_file, store_stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [store_stdout_file, store_stderr_file],
password=redis_password)
if not use_raylet:
if cleanup:
all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2)
@@ -1248,7 +1304,8 @@ def start_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True,
autoscaling_config=None):
autoscaling_config=None,
redis_password=None):
"""Run a process to monitor the other processes.
Args:
@@ -1264,6 +1321,7 @@ def start_monitor(redis_address,
Python process that imported services exits. This is True by
default.
autoscaling_config: path to autoscaling config file.
redis_password (str): The password of the redis server.
"""
monitor_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "monitor.py")
@@ -1273,17 +1331,22 @@ def start_monitor(redis_address,
]
if autoscaling_config:
command.append("--autoscaling-config=" + str(autoscaling_config))
if redis_password:
command.append("--redis-password=" + redis_password)
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
all_processes[PROCESS_TYPE_MONITOR].append(p)
record_log_files_in_redis(redis_address, node_ip_address,
[stdout_file, stderr_file])
record_log_files_in_redis(
redis_address,
node_ip_address, [stdout_file, stderr_file],
password=redis_password)
def start_raylet_monitor(redis_address,
stdout_file=None,
stderr_file=None,
cleanup=True):
cleanup=True,
redis_password=None):
"""Run a process to monitor the other processes.
Args:
@@ -1296,8 +1359,10 @@ def start_raylet_monitor(redis_address,
then this process will be killed by services.cleanup() when the
Python process that imported services exits. This is True by
default.
redis_password (str): The password of the redis server.
"""
gcs_ip_address, gcs_port = redis_address.split(":")
redis_password = redis_password or ""
command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port]
p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
if cleanup:
@@ -1314,6 +1379,7 @@ def start_ray_processes(address_info=None,
num_redis_shards=1,
redis_max_clients=None,
redis_protected_mode=False,
redis_password=None,
worker_path=None,
cleanup=True,
redirect_worker_output=False,
@@ -1359,6 +1425,8 @@ def start_ray_processes(address_info=None,
redis_protected_mode: True if we should start Redis in protected mode.
This will prevent clients from other machines from connecting and
is only done when Redis is started via ray.init().
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here
@@ -1444,7 +1512,8 @@ def start_ray_processes(address_info=None,
redirect_output=True,
redirect_worker_output=redirect_worker_output,
cleanup=cleanup,
protected_mode=redis_protected_mode)
protected_mode=redis_protected_mode,
password=redis_password)
address_info["redis_address"] = redis_address
time.sleep(0.1)
@@ -1457,18 +1526,20 @@ def start_ray_processes(address_info=None,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup,
autoscaling_config=autoscaling_config)
autoscaling_config=autoscaling_config,
redis_password=redis_password)
if use_raylet:
start_raylet_monitor(
redis_address,
stdout_file=monitor_stdout_file,
stderr_file=monitor_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
if redis_shards == []:
# Get redis shards from primary redis instance.
redis_ip_address, redis_port = redis_address.split(":")
redis_client = redis.StrictRedis(
host=redis_ip_address, port=redis_port)
host=redis_ip_address, port=redis_port, password=redis_password)
redis_shards = redis_client.lrange("RedisShards", start=0, end=-1)
redis_shards = [ray.utils.decode(shard) for shard in redis_shards]
address_info["redis_shards"] = redis_shards
@@ -1482,7 +1553,8 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=log_monitor_stdout_file,
stderr_file=log_monitor_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
# Start the global scheduler, if necessary.
if include_global_scheduler and not use_raylet:
@@ -1493,7 +1565,8 @@ def start_ray_processes(address_info=None,
node_ip_address,
stdout_file=global_scheduler_stdout_file,
stderr_file=global_scheduler_stderr_file,
cleanup=cleanup)
cleanup=cleanup,
redis_password=redis_password)
# Initialize with existing services.
if "object_store_addresses" not in address_info:
@@ -1537,7 +1610,8 @@ def start_ray_processes(address_info=None,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
use_raylet=use_raylet,
plasma_store_socket_name=plasma_store_socket_name)
plasma_store_socket_name=plasma_store_socket_name,
redis_password=redis_password)
object_store_addresses.append(object_store_address)
time.sleep(0.1)
@@ -1575,7 +1649,8 @@ def start_ray_processes(address_info=None,
stderr_file=local_scheduler_stderr_file,
cleanup=cleanup,
resources=resources[i],
num_workers=num_local_scheduler_workers)
num_workers=num_local_scheduler_workers,
redis_password=redis_password)
local_scheduler_socket_names.append(local_scheduler_name)
# Make sure that we have exactly num_local_schedulers instances of
@@ -1599,7 +1674,8 @@ def start_ray_processes(address_info=None,
num_workers=workers_per_local_scheduler[i],
stdout_file=raylet_stdout_file,
stderr_file=raylet_stderr_file,
cleanup=cleanup))
cleanup=cleanup,
redis_password=redis_password))
if not use_raylet:
# Start any workers that the local scheduler has not already started.
@@ -1645,6 +1721,7 @@ def start_ray_node(node_ip_address,
num_workers=0,
num_local_schedulers=1,
object_store_memory=None,
redis_password=None,
worker_path=None,
cleanup=True,
redirect_worker_output=False,
@@ -1673,6 +1750,8 @@ def start_ray_node(node_ip_address,
start.
object_store_memory (int): The maximum amount of memory (in bytes) to
let the plasma store use.
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
worker_path (str): The path of the source code that will be run by the
worker.
cleanup (bool): If cleanup is true, then the processes started here
@@ -1711,6 +1790,7 @@ def start_ray_node(node_ip_address,
num_workers=num_workers,
num_local_schedulers=num_local_schedulers,
object_store_memory=object_store_memory,
redis_password=redis_password,
worker_path=worker_path,
include_log_monitor=True,
cleanup=cleanup,
@@ -1741,6 +1821,7 @@ def start_ray_head(address_info=None,
num_redis_shards=None,
redis_max_clients=None,
redis_protected_mode=False,
redis_password=None,
include_webui=True,
plasma_directory=None,
huge_pages=False,
@@ -1792,6 +1873,8 @@ def start_ray_head(address_info=None,
redis_protected_mode: True if we should start Redis in protected mode.
This will prevent clients from other machines from connecting and
is only done when Redis is started via ray.init().
redis_password (str): Prevents external clients without the password
from connecting to Redis if provided.
include_webui: True if the UI should be started and false otherwise.
plasma_directory: A directory where the Plasma memory mapped files will
be created.
@@ -1832,6 +1915,7 @@ def start_ray_head(address_info=None,
num_redis_shards=num_redis_shards,
redis_max_clients=redis_max_clients,
redis_protected_mode=redis_protected_mode,
redis_password=redis_password,
plasma_directory=plasma_directory,
huge_pages=huge_pages,
autoscaling_config=autoscaling_config,