From a41bbc10ef5c3e0a8a3e8a36dbbe18171b556ada Mon Sep 17 00:00:00 2001 From: Peter Schafhalter Date: Tue, 16 Oct 2018 22:48:30 -0700 Subject: [PATCH] Add password authentication to Redis ports (#2952) * Implement Redis authentication * Throw exception for legacy Ray * Add test * Formatting * Fix bugs in CLI * Fix bugs in Raylet * Move default password to constants.h * Use pytest.fixture * Fix bug * Authenticate using formatted strings * Add missing passwords * Add test * Improve authentication of async contexts * Disable Redis authentication for credis * Update test for credis * Fix rebase artifacts * Fix formatting * Add workaround for issue #3045 * Increase timeout for test * Improve C++ readability * Fixes for CLI * Add security docs * Address comments * Address comments * Adress comments * Use ray.get * Fix lint --- .travis.yml | 2 + doc/source/index.rst | 1 + doc/source/security.rst | 55 +++++ python/ray/experimental/state.py | 9 +- python/ray/log_monitor.py | 21 +- python/ray/monitor.py | 27 ++- python/ray/scripts/scripts.py | 39 +++- python/ray/services.py | 188 +++++++++++++----- python/ray/test/test_ray_init.py | 65 ++++++ python/ray/worker.py | 47 ++++- python/ray/workers/default_worker.py | 12 +- src/ray/gcs/client.cc | 30 ++- src/ray/gcs/client.h | 7 +- src/ray/gcs/redis_context.cc | 31 ++- src/ray/gcs/redis_context.h | 3 +- src/ray/raylet/main.cc | 10 +- src/ray/raylet/monitor.cc | 4 +- src/ray/raylet/monitor.h | 2 +- src/ray/raylet/monitor_main.cc | 5 +- .../raylet/object_manager_integration_test.cc | 4 +- src/ray/raylet/raylet.cc | 10 +- src/ray/raylet/raylet.h | 5 +- 22 files changed, 462 insertions(+), 115 deletions(-) create mode 100644 doc/source/security.rst create mode 100644 python/ray/test/test_ray_init.py diff --git a/.travis.yml b/.travis.yml index 35743b764..8416fa138 100644 --- a/.travis.yml +++ b/.travis.yml @@ -133,6 +133,7 @@ matrix: - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py + - python -m pytest -v python/ray/test/test_ray_init.py - python -m pytest -v test/xray_test.py - python -m pytest -v test/runtest.py @@ -208,6 +209,7 @@ script: - python -m pytest -v python/ray/test/test_global_state.py - python -m pytest -v python/ray/test/test_queue.py + - python -m pytest -v python/ray/test/test_ray_init.py - python -m pytest -v test/xray_test.py - python -m pytest -v test/runtest.py diff --git a/doc/source/index.rst b/doc/source/index.rst index d8870bbaa..5268054b7 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -135,6 +135,7 @@ Ray comes with libraries that accelerate deep learning and reinforcement learnin troubleshooting.rst user-profiling.rst + security.rst development.rst profiling.rst contact.rst diff --git a/doc/source/security.rst b/doc/source/security.rst new file mode 100644 index 000000000..6b636c668 --- /dev/null +++ b/doc/source/security.rst @@ -0,0 +1,55 @@ +Security +======== + +This document describes best security practices for using Ray. + +Intended Use and Threat Model +----------------------------- + +Ray instances should run on a secure network without public facing ports. +The most common threat for Ray instances is unauthorized access to Redis, +which can be exploited to gain shell access and run arbitray code. +The best fix is to run Ray instances on a secure, trusted network. + +Running Ray on a secured network is not always feasible, so Ray +provides some basic security features: + + +Redis Port Authentication +------------------------- + +To prevent exploits via unauthorized Redis access, Ray provides the option to +password-protect Redis ports. While this is not a replacement for running Ray +behind a firewall, this feature is useful for instances exposed to the internet +where configuring a firewall is not possible. Because Redis is +very fast at serving queries, the chosen password should be long. + +Redis authentication is only supported on the raylet code path. + +To add authentication via the Python API, start Ray using: + +.. code-block:: python + + ray.init(redis_password="password") + +To add authentication via the CLI, or connect to an existing Ray instance with +password-protected Redis ports: + +.. code-block:: bash + + ray start [--head] --redis-password="password" + +While Redis port authentication may protect against external attackers, +Ray does not encrypt traffic between nodes so man-in-the-middle attacks are +possible for clusters on untrusted networks. + +Cloud Security +-------------- + +Launching Ray clusters on AWS or GCP using the ``ray up`` command +automatically configures security groups that prevent external Redis access. + +References +---------- + +- The `Redis security documentation ` diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index eab71993c..906d650d2 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -78,6 +78,7 @@ class GlobalState(object): def _initialize_global_state(self, redis_ip_address, redis_port, + redis_password=None, timeout=20): """Initialize the GlobalState object by connecting to Redis. @@ -89,9 +90,10 @@ class GlobalState(object): redis_ip_address: The IP address of the node that the Redis server lives on. redis_port: The port that the Redis server is listening on. + redis_password: The password of the redis server. """ self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) start_time = time.time() @@ -143,7 +145,10 @@ class GlobalState(object): for ip_address_port in ip_address_ports: shard_address, shard_port = ip_address_port.split(b":") self.redis_clients.append( - redis.StrictRedis(host=shard_address, port=shard_port)) + redis.StrictRedis( + host=shard_address, + port=shard_port, + password=redis_password)) def _execute_command(self, key, *args): """Execute a Redis command on the appropriate Redis shard based on key. diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 13a62a98a..2cd6fc40a 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -35,11 +35,15 @@ class LogMonitor(object): handle for that file. """ - def __init__(self, redis_ip_address, redis_port, node_ip_address): + def __init__(self, + redis_ip_address, + redis_port, + node_ip_address, + redis_password=None): """Initialize the log monitor object.""" self.node_ip_address = node_ip_address self.redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) self.log_files = {} self.log_file_handles = {} self.files_to_ignore = set() @@ -130,6 +134,12 @@ if __name__ == "__main__": required=True, type=str, help="The IP address of the node this process is on.") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -151,6 +161,9 @@ if __name__ == "__main__": redis_ip_address = get_ip_address(args.redis_address) redis_port = get_port(args.redis_address) - log_monitor = LogMonitor(redis_ip_address, redis_port, - args.node_ip_address) + log_monitor = LogMonitor( + redis_ip_address, + redis_port, + args.node_ip_address, + redis_password=args.redis_password) log_monitor.run() diff --git a/python/ray/monitor.py b/python/ray/monitor.py index e5c2279b7..6212de23e 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -70,13 +70,18 @@ class Monitor(object): managers that were up at one point and have died since then. """ - def __init__(self, redis_address, redis_port, autoscaling_config): + def __init__(self, + redis_address, + redis_port, + autoscaling_config, + redis_password=None): # Initialize the Redis clients. self.state = ray.experimental.state.GlobalState() - self.state._initialize_global_state(redis_address, redis_port) + self.state._initialize_global_state( + redis_address, redis_port, redis_password=redis_password) self.use_raylet = self.state.use_raylet self.redis = redis.StrictRedis( - host=redis_address, port=redis_port, db=0) + host=redis_address, port=redis_port, db=0, password=redis_password) # Setup subscriptions to the primary Redis server and the Redis shards. self.primary_subscribe_client = self.redis.pubsub( ignore_subscribe_messages=True) @@ -118,7 +123,9 @@ class Monitor(object): else: addr_port = addr_port[0].split(b":") self.redis_shard = redis.StrictRedis( - host=addr_port[0], port=addr_port[1]) + host=addr_port[0], + port=addr_port[1], + password=redis_password) try: self.redis_shard.execute_command("HEAD.FLUSH 0") except redis.exceptions.ResponseError as e: @@ -773,6 +780,12 @@ if __name__ == "__main__": required=False, type=str, help="the path to the autoscaling config file") + parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--logging-level", required=False, @@ -798,7 +811,11 @@ if __name__ == "__main__": else: autoscaling_config = None - monitor = Monitor(redis_ip_address, redis_port, autoscaling_config) + monitor = Monitor( + redis_ip_address, + redis_port, + autoscaling_config, + redis_password=args.redis_password) try: monitor.run() diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index f8e0c5484..dfbbee272 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -89,6 +89,11 @@ def cli(logging_level, logging_format): type=int, help=("If provided, attempt to configure Redis with this " "maximum number of clients.")) +@click.option( + "--redis-password", + required=False, + type=str, + help="If provided, secure Redis ports with this password") @click.option( "--redis-shard-ports", required=False, @@ -190,10 +195,11 @@ def cli(logging_level, logging_format): default=None, help="manually specify the root temporary dir of the Ray process") def start(node_ip_address, redis_address, redis_port, num_redis_shards, - redis_max_clients, redis_shard_ports, object_manager_port, - object_store_memory, num_workers, num_cpus, num_gpus, resources, - head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, - use_raylet, no_redirect_worker_output, no_redirect_output, + redis_max_clients, redis_password, redis_shard_ports, + object_manager_port, object_store_memory, num_workers, num_cpus, + num_gpus, resources, head, no_ui, block, plasma_directory, + huge_pages, autoscaling_config, use_raylet, + no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir): # Convert hostnames to numerical IP address. if node_ip_address is not None: @@ -205,6 +211,11 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # This environment variable is used in our testing setup. logger.info("Detected environment variable 'RAY_USE_XRAY'.") use_raylet = True + if not use_raylet and redis_password is not None: + raise Exception("Setting the 'redis-password' argument is not " + "supported in legacy Ray. To run Ray with " + "password-protected Redis ports, pass " + "the '--use-raylet' flag.") try: resources = json.loads(resources) @@ -269,6 +280,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, redis_protected_mode=False, + redis_password=redis_password, include_webui=(not no_ui), plasma_directory=plasma_directory, huge_pages=huge_pages, @@ -281,16 +293,20 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, logger.info( "\nStarted Ray on this node. You can add additional nodes to " "the cluster by calling\n\n" - " ray start --redis-address {}\n\n" + " ray start --redis-address {}{}{}\n\n" "from the node you wish to add. You can connect a driver to the " "cluster from Python by running\n\n" " import ray\n" - " ray.init(redis_address=\"{}\")\n\n" + " ray.init(redis_address=\"{}{}{}\")\n\n" "If you have trouble connecting from a different machine, check " "that your firewall is configured properly. If you wish to " "terminate the processes that have been started, run\n\n" - " ray stop".format(address_info["redis_address"], - address_info["redis_address"])) + " ray stop".format( + address_info["redis_address"], " --redis-password " + if redis_password else "", redis_password if redis_password + else "", address_info["redis_address"], "\", redis_password=\"" + if redis_password else "", redis_password + if redis_password else "")) else: # Start Ray on a non-head node. if redis_port is not None: @@ -315,10 +331,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. - services.wait_for_redis_to_start(redis_ip_address, int(redis_port)) + services.wait_for_redis_to_start( + redis_ip_address, int(redis_port), password=redis_password) # Create a Redis client. - redis_client = services.create_redis_client(redis_address) + redis_client = services.create_redis_client( + redis_address, password=redis_password) # Check that the verion information on this node matches the version # information that the cluster was started with. @@ -339,6 +357,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, object_manager_ports=[object_manager_port], num_workers=num_workers, object_store_memory=object_store_memory, + redis_password=redis_password, cleanup=False, redirect_worker_output=not no_redirect_worker_output, redirect_output=not no_redirect_output, diff --git a/python/ray/services.py b/python/ray/services.py index 9b1592e7b..e572b657f 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -261,7 +261,10 @@ def get_node_ip_address(address="8.8.8.8:53"): return node_ip_address -def record_log_files_in_redis(redis_address, node_ip_address, log_files): +def record_log_files_in_redis(redis_address, + node_ip_address, + log_files, + password=None): """Record in Redis that a new log file has been created. This is used so that each log monitor can check Redis and figure out which @@ -273,23 +276,24 @@ def record_log_files_in_redis(redis_address, node_ip_address, log_files): on. log_files: A list of file handles for the log files. If one of the file handles is None, we ignore it. + password (str): The password of the redis server. """ for log_file in log_files: if log_file is not None: redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=password) # The name of the key storing the list of log filenames for this IP # address. log_file_list_key = "LOG_FILENAMES:{}".format(node_ip_address) redis_client.rpush(log_file_list_key, log_file.name) -def create_redis_client(redis_address): +def create_redis_client(redis_address, password=None): """Create a Redis client. Args: - The IP address and port of the Redis server. + The IP address, port, and password of the Redis server. Returns: A Redis client. @@ -297,10 +301,14 @@ def create_redis_client(redis_address): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine # as Redis) must have run "CONFIG SET protected-mode no". - return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + return redis.StrictRedis( + host=redis_ip_address, port=int(redis_port), password=password) -def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): +def wait_for_redis_to_start(redis_ip_address, + redis_port, + password=None, + num_retries=5): """Wait for a Redis server to be available. This is accomplished by creating a Redis client and sending a random @@ -309,13 +317,15 @@ def wait_for_redis_to_start(redis_ip_address, redis_port, num_retries=5): Args: redis_ip_address (str): The IP address of the redis server. redis_port (int): The port of the redis server. + password (str): The password of the redis server. num_retries (int): The number of times to try connecting with redis. The client will sleep for one second between attempts. Raises: Exception: An exception is raised if we could not connect with Redis. """ - redis_client = redis.StrictRedis(host=redis_ip_address, port=redis_port) + redis_client = redis.StrictRedis( + host=redis_ip_address, port=redis_port, password=password) # Wait for the Redis server to start. counter = 0 while counter < num_retries: @@ -425,6 +435,7 @@ def start_redis(node_ip_address, redirect_worker_output=False, cleanup=True, protected_mode=False, + password=None, use_credis=None): """Start the Redis global state store. @@ -451,6 +462,8 @@ def start_redis(node_ip_address, then all Redis processes started by this method will be killed by services.cleanup() when the Python process that imported services exits. + password (str): Prevents external clients without the password + from connecting to Redis if provided. use_credis: If True, additionally load the chain-replicated libraries into the redis servers. Defaults to None, which means its value is set by the presence of "RAY_USE_NEW_GCS" in os.environ. @@ -469,6 +482,13 @@ def start_redis(node_ip_address, if use_credis is None: use_credis = ("RAY_USE_NEW_GCS" in os.environ) + if use_credis and password is not None: + # TODO(pschafhalter) remove this once credis supports + # authenticating Redis ports + raise Exception("Setting the `redis_password` argument is not " + "supported in credis. To run Ray with " + "password-protected Redis ports, ensure that " + "the environment variable `RAY_USE_NEW_GCS=off`.") if not use_credis: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -477,7 +497,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + protected_mode=protected_mode, + password=password) else: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -491,20 +512,23 @@ def start_redis(node_ip_address, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former # supplies. - modules=[CREDIS_MASTER_MODULE, REDIS_MODULE]) + modules=[CREDIS_MASTER_MODULE, REDIS_MODULE], + password=password) if port is not None: assert assigned_port == port port = assigned_port redis_address = address(node_ip_address, port) - redis_client = redis.StrictRedis(host=node_ip_address, port=port) + redis_client = redis.StrictRedis( + host=node_ip_address, port=port, password=password) # Store whether we're using the raylet code path or not. redis_client.set("UseRaylet", 1 if use_raylet else 0) # Register the number of Redis shards in the primary shard, so that clients # know how many redis shards to expect under RedisShards. - primary_redis_client = redis.StrictRedis(host=node_ip_address, port=port) + primary_redis_client = redis.StrictRedis( + host=node_ip_address, port=port, password=password) primary_redis_client.set("NumRedisShards", str(num_redis_shards)) # Put the redirect_worker_output bool in the Redis shard so that workers @@ -529,7 +553,8 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, - protected_mode=protected_mode) + protected_mode=protected_mode, + password=password) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ @@ -542,6 +567,7 @@ def start_redis(node_ip_address, stderr_file=redis_stderr_file, cleanup=cleanup, protected_mode=protected_mode, + password=password, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray # module, as the latter contains an extern declaration that the @@ -557,7 +583,7 @@ def start_redis(node_ip_address, if use_credis: shard_client = redis.StrictRedis( - host=node_ip_address, port=redis_shard_port) + host=node_ip_address, port=redis_shard_port, password=password) # Configure the chain state. primary_redis_client.execute_command("MASTER.ADD", node_ip_address, redis_shard_port) @@ -591,6 +617,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stderr_file=None, cleanup=True, protected_mode=False, + password=None, executable=REDIS_EXECUTABLE, modules=None): """Start a single Redis server. @@ -614,6 +641,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", mode. This will prevent clients on other machines from connecting and is only used when the Redis servers are started via ray.init() as opposed to ray start. + password (str): Prevents external clients without the password + from connecting to Redis if provided. executable (str): Full path tho the redis-server executable. modules (list of str): A list of pathnames, pointing to the redis module(s) that will be loaded in this redis server. If None, load @@ -654,6 +683,8 @@ def _start_redis_instance(node_ip_address="127.0.0.1", command = [executable] if protected_mode: command += [redis_config_filename] + if password: + command += ["--requirepass", password] command += ( ["--port", str(port), "--loglevel", "warning"] + load_module_args) @@ -672,9 +703,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file.name, stderr_file.name)) # Create a Redis client just for configuring Redis. - redis_client = redis.StrictRedis(host="127.0.0.1", port=port) + redis_client = redis.StrictRedis( + host="127.0.0.1", port=port, password=password) # Wait for the Redis server to start. - wait_for_redis_to_start("127.0.0.1", port) + wait_for_redis_to_start("127.0.0.1", port, password=password) # Configure Redis to generate keyspace notifications. TODO(rkn): Change # this to only generate notifications for the export keys. redis_client.config_set("notify-keyspace-events", "Kl") @@ -719,8 +751,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1", redis_client.set("redis_start_time", time.time()) # Record the log files in Redis. record_log_files_in_redis( - address(node_ip_address, port), node_ip_address, - [stdout_file, stderr_file]) + address(node_ip_address, port), + node_ip_address, [stdout_file, stderr_file], + password=password) return port, p @@ -728,7 +761,8 @@ def start_log_monitor(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=cleanup): + cleanup=cleanup, + redis_password=None): """Start a log monitor process. Args: @@ -742,27 +776,31 @@ def start_log_monitor(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by services.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. """ log_monitor_filepath = os.path.join( os.path.dirname(os.path.abspath(__file__)), "log_monitor.py") - p = subprocess.Popen( - [ - sys.executable, "-u", log_monitor_filepath, "--redis-address", - redis_address, "--node-ip-address", node_ip_address - ], - stdout=stdout_file, - stderr=stderr_file) + command = [ + sys.executable, "-u", log_monitor_filepath, "--redis-address", + redis_address, "--node-ip-address", node_ip_address + ] + if redis_password: + command += ["--redis-password", redis_password] + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_LOG_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_global_scheduler(redis_address, node_ip_address, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Start a global scheduler process. Args: @@ -776,6 +814,7 @@ def start_global_scheduler(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by services.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. """ p = global_scheduler.start_global_scheduler( redis_address, @@ -784,8 +823,10 @@ def start_global_scheduler(redis_address, stderr_file=stderr_file) if cleanup: all_processes[PROCESS_TYPE_GLOBAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_ui(redis_address, stdout_file=None, stderr_file=None, cleanup=True): @@ -911,7 +952,8 @@ def start_local_scheduler(redis_address, stderr_file=None, cleanup=True, resources=None, - num_workers=0): + num_workers=0, + redis_password=None): """Start a local scheduler process. Args: @@ -935,6 +977,7 @@ def start_local_scheduler(redis_address, quantity of that resource. num_workers (int): The number of workers that the local scheduler should start. + redis_password (str): The password of the redis server. Return: The name of the local scheduler socket. @@ -957,8 +1000,10 @@ def start_local_scheduler(redis_address, num_workers=num_workers) if cleanup: all_processes[PROCESS_TYPE_LOCAL_SCHEDULER].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) return local_scheduler_name @@ -973,7 +1018,8 @@ def start_raylet(redis_address, use_profiler=False, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -996,6 +1042,7 @@ def start_raylet(redis_address, cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. + redis_password (str): The password of the redis server. Returns: The raylet socket name. @@ -1029,6 +1076,8 @@ def start_raylet(redis_address, sys.executable, worker_path, node_ip_address, plasma_store_name, raylet_name, redis_address, get_temp_root())) + if redis_password: + start_worker_command += " --redis-password {}".format(redis_password) command = [ RAYLET_EXECUTABLE, @@ -1042,6 +1091,7 @@ def start_raylet(redis_address, resource_argument, start_worker_command, "", # Worker command for Java, not needed for Python. + redis_password or "", ] if use_valgrind: @@ -1063,8 +1113,10 @@ def start_raylet(redis_address, if cleanup: all_processes[PROCESS_TYPE_RAYLET].append(pid) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) return raylet_name @@ -1081,7 +1133,8 @@ def start_plasma_store(node_ip_address, plasma_directory=None, huge_pages=False, use_raylet=False, - plasma_store_socket_name=None): + plasma_store_socket_name=None, + redis_password=None): """This method starts an object store process. Args: @@ -1111,6 +1164,7 @@ def start_plasma_store(node_ip_address, Store with hugetlbfs support. Requires plasma_directory. use_raylet: True if the new raylet code path should be used. This is not supported yet. + redis_password (str): The password of the redis server. Return: A tuple of the Plasma store socket name, the Plasma manager socket @@ -1186,8 +1240,10 @@ def start_plasma_store(node_ip_address, if cleanup: all_processes[PROCESS_TYPE_PLASMA_STORE].append(p1) - record_log_files_in_redis(redis_address, node_ip_address, - [store_stdout_file, store_stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [store_stdout_file, store_stderr_file], + password=redis_password) if not use_raylet: if cleanup: all_processes[PROCESS_TYPE_PLASMA_MANAGER].append(p2) @@ -1248,7 +1304,8 @@ def start_monitor(redis_address, stdout_file=None, stderr_file=None, cleanup=True, - autoscaling_config=None): + autoscaling_config=None, + redis_password=None): """Run a process to monitor the other processes. Args: @@ -1264,6 +1321,7 @@ def start_monitor(redis_address, Python process that imported services exits. This is True by default. autoscaling_config: path to autoscaling config file. + redis_password (str): The password of the redis server. """ monitor_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "monitor.py") @@ -1273,17 +1331,22 @@ def start_monitor(redis_address, ] if autoscaling_config: command.append("--autoscaling-config=" + str(autoscaling_config)) + if redis_password: + command.append("--redis-password=" + redis_password) p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: all_processes[PROCESS_TYPE_MONITOR].append(p) - record_log_files_in_redis(redis_address, node_ip_address, - [stdout_file, stderr_file]) + record_log_files_in_redis( + redis_address, + node_ip_address, [stdout_file, stderr_file], + password=redis_password) def start_raylet_monitor(redis_address, stdout_file=None, stderr_file=None, - cleanup=True): + cleanup=True, + redis_password=None): """Run a process to monitor the other processes. Args: @@ -1296,8 +1359,10 @@ def start_raylet_monitor(redis_address, then this process will be killed by services.cleanup() when the Python process that imported services exits. This is True by default. + redis_password (str): The password of the redis server. """ gcs_ip_address, gcs_port = redis_address.split(":") + redis_password = redis_password or "" command = [RAYLET_MONITOR_EXECUTABLE, gcs_ip_address, gcs_port] p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) if cleanup: @@ -1314,6 +1379,7 @@ def start_ray_processes(address_info=None, num_redis_shards=1, redis_max_clients=None, redis_protected_mode=False, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1359,6 +1425,8 @@ def start_ray_processes(address_info=None, redis_protected_mode: True if we should start Redis in protected mode. This will prevent clients from other machines from connecting and is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1444,7 +1512,8 @@ def start_ray_processes(address_info=None, redirect_output=True, redirect_worker_output=redirect_worker_output, cleanup=cleanup, - protected_mode=redis_protected_mode) + protected_mode=redis_protected_mode, + password=redis_password) address_info["redis_address"] = redis_address time.sleep(0.1) @@ -1457,18 +1526,20 @@ def start_ray_processes(address_info=None, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, cleanup=cleanup, - autoscaling_config=autoscaling_config) + autoscaling_config=autoscaling_config, + redis_password=redis_password) if use_raylet: start_raylet_monitor( redis_address, stdout_file=monitor_stdout_file, stderr_file=monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) if redis_shards == []: # Get redis shards from primary redis instance. redis_ip_address, redis_port = redis_address.split(":") redis_client = redis.StrictRedis( - host=redis_ip_address, port=redis_port) + host=redis_ip_address, port=redis_port, password=redis_password) redis_shards = redis_client.lrange("RedisShards", start=0, end=-1) redis_shards = [ray.utils.decode(shard) for shard in redis_shards] address_info["redis_shards"] = redis_shards @@ -1482,7 +1553,8 @@ def start_ray_processes(address_info=None, node_ip_address, stdout_file=log_monitor_stdout_file, stderr_file=log_monitor_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) # Start the global scheduler, if necessary. if include_global_scheduler and not use_raylet: @@ -1493,7 +1565,8 @@ def start_ray_processes(address_info=None, node_ip_address, stdout_file=global_scheduler_stdout_file, stderr_file=global_scheduler_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + redis_password=redis_password) # Initialize with existing services. if "object_store_addresses" not in address_info: @@ -1537,7 +1610,8 @@ def start_ray_processes(address_info=None, plasma_directory=plasma_directory, huge_pages=huge_pages, use_raylet=use_raylet, - plasma_store_socket_name=plasma_store_socket_name) + plasma_store_socket_name=plasma_store_socket_name, + redis_password=redis_password) object_store_addresses.append(object_store_address) time.sleep(0.1) @@ -1575,7 +1649,8 @@ def start_ray_processes(address_info=None, stderr_file=local_scheduler_stderr_file, cleanup=cleanup, resources=resources[i], - num_workers=num_local_scheduler_workers) + num_workers=num_local_scheduler_workers, + redis_password=redis_password) local_scheduler_socket_names.append(local_scheduler_name) # Make sure that we have exactly num_local_schedulers instances of @@ -1599,7 +1674,8 @@ def start_ray_processes(address_info=None, num_workers=workers_per_local_scheduler[i], stdout_file=raylet_stdout_file, stderr_file=raylet_stderr_file, - cleanup=cleanup)) + cleanup=cleanup, + redis_password=redis_password)) if not use_raylet: # Start any workers that the local scheduler has not already started. @@ -1645,6 +1721,7 @@ def start_ray_node(node_ip_address, num_workers=0, num_local_schedulers=1, object_store_memory=None, + redis_password=None, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1673,6 +1750,8 @@ def start_ray_node(node_ip_address, start. object_store_memory (int): The maximum amount of memory (in bytes) to let the plasma store use. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1711,6 +1790,7 @@ def start_ray_node(node_ip_address, num_workers=num_workers, num_local_schedulers=num_local_schedulers, object_store_memory=object_store_memory, + redis_password=redis_password, worker_path=worker_path, include_log_monitor=True, cleanup=cleanup, @@ -1741,6 +1821,7 @@ def start_ray_head(address_info=None, num_redis_shards=None, redis_max_clients=None, redis_protected_mode=False, + redis_password=None, include_webui=True, plasma_directory=None, huge_pages=False, @@ -1792,6 +1873,8 @@ def start_ray_head(address_info=None, redis_protected_mode: True if we should start Redis in protected mode. This will prevent clients from other machines from connecting and is only done when Redis is started via ray.init(). + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. include_webui: True if the UI should be started and false otherwise. plasma_directory: A directory where the Plasma memory mapped files will be created. @@ -1832,6 +1915,7 @@ def start_ray_head(address_info=None, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, redis_protected_mode=redis_protected_mode, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, diff --git a/python/ray/test/test_ray_init.py b/python/ray/test/test_ray_init.py new file mode 100644 index 000000000..62d581003 --- /dev/null +++ b/python/ray/test/test_ray_init.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import pytest +import redis + +import ray + + +@pytest.fixture +def password(): + random_bytes = os.urandom(128) + if hasattr(random_bytes, "hex"): + return random_bytes.hex() # Python 3 + return random_bytes.encode("hex") # Python 2 + + +@pytest.fixture +def shutdown_only(): + yield None + # The code after the yield will run as teardown code. + ray.shutdown() + + +class TestRedisPassword(object): + @pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") != "on" + and os.environ.get("RAY_USE_XRAY"), + reason="Redis authentication works for raylet and old GCS.") + def test_exceptions(self, password, shutdown_only): + with pytest.raises(Exception): + ray.init(redis_password=password) + + @pytest.mark.skipif( + os.environ.get("RAY_USE_NEW_GCS") == "on", + reason="New GCS API doesn't support Redis authentication yet.") + @pytest.mark.skipif( + not os.environ.get("RAY_USE_XRAY"), + reason="Redis authentication is not supported in legacy Ray.") + def test_redis_password(self, password, shutdown_only): + # Workaround for https://github.com/ray-project/ray/issues/3045 + @ray.remote + def f(): + return 1 + + info = ray.init(redis_password=password) + redis_address = info["redis_address"] + redis_ip, redis_port = redis_address.split(":") + + # Check that we can run a task + object_id = f.remote() + ray.get(object_id) + + # Check that Redis connections require a password + redis_client = redis.StrictRedis( + host=redis_ip, port=redis_port, password=None) + with pytest.raises(redis.ResponseError): + redis_client.ping() + + # Check that we can connect to Redis using the provided password + redis_client = redis.StrictRedis( + host=redis_ip, port=redis_port, password=password) + assert redis_client.ping() diff --git a/python/ray/worker.py b/python/ray/worker.py index 4739b2e7c..f30b46448 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1223,12 +1223,13 @@ def _initialize_serialization(driver_id, worker=global_worker): def get_address_info_from_redis_helper(redis_address, node_ip_address, - use_raylet=False): + use_raylet=False, + redis_password=None): redis_ip_address, redis_port = redis_address.split(":") # For this command to work, some other client (on the same machine as # Redis) must have run "CONFIG SET protected-mode no". redis_client = redis.StrictRedis( - host=redis_ip_address, port=int(redis_port)) + host=redis_ip_address, port=int(redis_port), password=redis_password) if not use_raylet: # The client table prefix must be kept in sync with the file @@ -1332,12 +1333,16 @@ def get_address_info_from_redis_helper(redis_address, def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5, - use_raylet=False): + use_raylet=False, + redis_password=None): counter = 0 while True: try: return get_address_info_from_redis_helper( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, + node_ip_address, + use_raylet=use_raylet, + redis_password=None) except Exception: if counter == num_retries: raise @@ -1405,6 +1410,7 @@ def _init(address_info=None, resources=None, num_redis_shards=None, redis_max_clients=None, + redis_password=None, plasma_directory=None, huge_pages=False, include_webui=True, @@ -1460,6 +1466,8 @@ def _init(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object @@ -1544,6 +1552,7 @@ def _init(address_info=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, @@ -1596,7 +1605,10 @@ def _init(address_info=None, node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. address_info = get_address_info_from_redis( - redis_address, node_ip_address, use_raylet=use_raylet) + redis_address, + node_ip_address, + use_raylet=use_raylet, + redis_password=redis_password) # Connect this driver to Redis, the object store, and the local scheduler. # Choose the first object store and local scheduler if there are multiple. @@ -1628,7 +1640,8 @@ def _init(address_info=None, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker, - use_raylet=use_raylet) + use_raylet=use_raylet, + redis_password=redis_password) return address_info @@ -1647,6 +1660,7 @@ def init(redis_address=None, ignore_reinit_error=False, num_redis_shards=None, redis_max_clients=None, + redis_password=None, plasma_directory=None, huge_pages=False, include_webui=True, @@ -1709,6 +1723,8 @@ def init(redis_address=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. plasma_directory: A directory where the Plasma memory mapped files will be created. huge_pages: Boolean flag indicating whether to start the Object @@ -1750,6 +1766,11 @@ def init(redis_address=None, # This environment variable is used in our testing setup. logger.info("Detected environment variable 'RAY_USE_XRAY'.") use_raylet = True + if not use_raylet and redis_password is not None: + raise Exception("Setting the 'redis_password' argument is not " + "supported in legacy Ray. To run Ray with " + "password-protected Redis ports, set " + "'use_raylet=True'.") # Convert hostnames to numerical IP address. if node_ip_address is not None: @@ -1772,6 +1793,7 @@ def init(redis_address=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, include_webui=include_webui, @@ -1975,7 +1997,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, - use_raylet=False): + use_raylet=False, + redis_password=None): """Connect this worker to the local scheduler, to Plasma, and to Redis. Args: @@ -1986,6 +2009,8 @@ def connect(info, mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, and LOCAL_MODE. use_raylet: True if the new raylet code path should be used. + redis_password (str): Prevents external clients without the password + from connecting to Redis if provided. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" @@ -2019,7 +2044,10 @@ def connect(info, # Create a Redis client. redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = thread_safe_client( - redis.StrictRedis(host=redis_ip_address, port=int(redis_port))) + redis.StrictRedis( + host=redis_ip_address, + port=int(redis_port), + password=redis_password)) # For driver's check that the version information matches the version # information that the Ray cluster was started with. @@ -2060,7 +2088,8 @@ def connect(info, [log_stdout_file, log_stderr_file]) # Create an object for interfacing with the global state. - global_state._initialize_global_state(redis_ip_address, int(redis_port)) + global_state._initialize_global_state( + redis_ip_address, int(redis_port), redis_password=redis_password) # Register the worker with Redis. if mode == SCRIPT_MODE: diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 72679722f..670ee092d 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -24,6 +24,12 @@ parser.add_argument( required=True, type=str, help="the address to use for Redis") +parser.add_argument( + "--redis-password", + required=False, + type=str, + default=None, + help="the password to use for Redis") parser.add_argument( "--object-store-name", required=True, @@ -67,6 +73,7 @@ if __name__ == "__main__": info = { "node_ip_address": args.node_ip_address, "redis_address": args.redis_address, + "redis_password": args.redis_password, "store_socket_name": args.object_store_name, "manager_socket_name": args.object_store_manager_name, "local_scheduler_socket_name": args.local_scheduler_name, @@ -81,7 +88,10 @@ if __name__ == "__main__": tempfile_services.set_temp_root(args.temp_dir) ray.worker.connect( - info, mode=ray.WORKER_MODE, use_raylet=(args.raylet_name is not None)) + info, + mode=ray.WORKER_MODE, + use_raylet=(args.raylet_name is not None), + redis_password=args.redis_password) error_explanation = """ This error is unexpected and should not have happened. Somehow a worker diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 182c44a8a..3acd5623a 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -71,10 +71,12 @@ namespace gcs { AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, CommandType command_type, - bool is_test_client = false) { + bool is_test_client = false, + const std::string &password = "") { primary_context_ = std::make_shared(); - RAY_CHECK_OK(primary_context_->Connect(address, port, /*sharding=*/true)); + RAY_CHECK_OK( + primary_context_->Connect(address, port, /*sharding=*/true, /*password=*/password)); if (!is_test_client) { // Moving sharding into constructor defaultly means that sharding = true. @@ -94,12 +96,13 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, RAY_CHECK(shard_contexts_.size() == addresses.size()); for (size_t i = 0; i < addresses.size(); ++i) { - RAY_CHECK_OK( - shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true)); + RAY_CHECK_OK(shard_contexts_[i]->Connect(addresses[i], ports[i], /*sharding=*/true, + /*password=*/password)); } } else { shard_contexts_.push_back(std::make_shared()); - RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true)); + RAY_CHECK_OK(shard_contexts_[0]->Connect(address, port, /*sharding=*/true, + /*password=*/password)); } client_table_.reset(new ClientTable({primary_context_}, this, client_id)); @@ -126,12 +129,16 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, // Use of kChain currently only applies to Table::Add which affects only the // task table, and when RAY_USE_NEW_GCS is set at compile time. AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false) - : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client) {} + const ClientID &client_id, bool is_test_client = false, + const std::string &password = "") + : AsyncGcsClient(address, port, client_id, CommandType::kChain, is_test_client, + password) {} #else AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, - const ClientID &client_id, bool is_test_client = false) - : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client) {} + const ClientID &client_id, bool is_test_client = false, + const std::string &password = "") + : AsyncGcsClient(address, port, client_id, CommandType::kRegular, is_test_client, + password) {} #endif // RAY_USE_NEW_GCS AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, @@ -143,8 +150,9 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, : AsyncGcsClient(address, port, ClientID::from_random(), command_type, is_test_client) {} -AsyncGcsClient::AsyncGcsClient(const std::string &address, int port) - : AsyncGcsClient(address, port, ClientID::from_random()) {} +AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, + const std::string &password = "") + : AsyncGcsClient(address, port, ClientID::from_random(), false, password) {} AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, bool is_test_client) : AsyncGcsClient(address, port, ClientID::from_random(), is_test_client) {} diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index d89aadd80..83781e841 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -31,13 +31,14 @@ class RAY_EXPORT AsyncGcsClient { /// \param command_type GCS command type. If CommandType::kChain, chain-replicated /// versions of the tables might be used, if available. AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, - CommandType command_type, bool is_test_client); + CommandType command_type, bool is_test_client, + const std::string &redis_password); AsyncGcsClient(const std::string &address, int port, const ClientID &client_id, - bool is_test_client); + bool is_test_client, const std::string &password); AsyncGcsClient(const std::string &address, int port, CommandType command_type); AsyncGcsClient(const std::string &address, int port, CommandType command_type, bool is_test_client); - AsyncGcsClient(const std::string &address, int port); + AsyncGcsClient(const std::string &address, int port, const std::string &password); AsyncGcsClient(const std::string &address, int port, bool is_test_client); /// Attach this client to a plasma event loop. Note that only diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index abc06a24a..1a8111963 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -135,7 +135,30 @@ RedisContext::~RedisContext() { } } -Status RedisContext::Connect(const std::string &address, int port, bool sharding) { +Status AuthenticateRedis(redisContext *context, const std::string &password) { + if (password == "") { + return Status::OK(); + } + redisReply *reply = + reinterpret_cast(redisCommand(context, "AUTH %s", password.c_str())); + REDIS_CHECK_ERROR(context, reply); + freeReplyObject(reply); + return Status::OK(); +} + +Status AuthenticateRedis(redisAsyncContext *context, const std::string &password) { + if (password == "") { + return Status::OK(); + } + int status = redisAsyncCommand(context, NULL, NULL, "AUTH %s", password.c_str()); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(context->errstr)); + } + return Status::OK(); +} + +Status RedisContext::Connect(const std::string &address, int port, bool sharding, + const std::string &password = "") { int connection_attempts = 0; context_ = redisConnect(address.c_str(), port); while (context_ == nullptr || context_->err) { @@ -155,6 +178,8 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding context_ = redisConnect(address.c_str(), port); connection_attempts += 1; } + RAY_CHECK_OK(AuthenticateRedis(context_, password)); + redisReply *reply = reinterpret_cast( redisCommand(context_, "CONFIG SET notify-keyspace-events Kl")); REDIS_CHECK_ERROR(context_, reply); @@ -166,12 +191,16 @@ Status RedisContext::Connect(const std::string &address, int port, bool sharding RAY_LOG(FATAL) << "Could not establish connection to redis " << address << ":" << port; } + RAY_CHECK_OK(AuthenticateRedis(async_context_, password)); + // Connect to subscribe context subscribe_context_ = redisAsyncConnect(address.c_str(), port); if (subscribe_context_ == nullptr || subscribe_context_->err) { RAY_LOG(FATAL) << "Could not establish subscribe connection to redis " << address << ":" << port; } + RAY_CHECK_OK(AuthenticateRedis(subscribe_context_, password)); + return Status::OK(); } diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 67bc8197c..1fcfd55ad 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -51,7 +51,8 @@ class RedisContext { RedisContext() : context_(nullptr), async_context_(nullptr), subscribe_context_(nullptr) {} ~RedisContext(); - Status Connect(const std::string &address, int port, bool sharding); + Status Connect(const std::string &address, int port, bool sharding, + const std::string &password); Status AttachToEventLoop(aeEventLoop *loop); /// Run an operation on some table key. diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 23aa41f25..8ad70a928 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -19,7 +19,7 @@ int main(int argc, char *argv[]) { ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); - RAY_CHECK(argc == 11); + RAY_CHECK(argc == 11 || argc == 12); const std::string raylet_socket_name = std::string(argv[1]); const std::string store_socket_name = std::string(argv[2]); @@ -31,6 +31,7 @@ int main(int argc, char *argv[]) { const std::string static_resource_list = std::string(argv[8]); const std::string python_worker_command = std::string(argv[9]); const std::string java_worker_command = std::string(argv[10]); + const std::string redis_password = (argc == 12 ? std::string(argv[11]) : ""); // Configuration for the node manager. ray::raylet::NodeManagerConfig node_manager_config; @@ -92,7 +93,8 @@ int main(int argc, char *argv[]) { << "object_chunk_size = " << object_manager_config.object_chunk_size; // initialize mock gcs & object directory - auto gcs_client = std::make_shared(redis_address, redis_port); + auto gcs_client = std::make_shared(redis_address, redis_port, + redis_password); RAY_LOG(DEBUG) << "Initializing GCS client " << gcs_client->client_table().GetLocalClientId(); @@ -100,8 +102,8 @@ int main(int argc, char *argv[]) { boost::asio::io_service main_service; ray::raylet::Raylet server(main_service, raylet_socket_name, node_ip_address, - redis_address, redis_port, node_manager_config, - object_manager_config, gcs_client); + redis_address, redis_port, redis_password, + node_manager_config, object_manager_config, gcs_client); // Destroy the Raylet on a SIGTERM. The pointer to main_service is // guaranteed to be valid since this function will run the event loop diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 05cf79309..da9d5f8ab 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -15,8 +15,8 @@ namespace raylet { /// the Ray configuration), then the monitor will mark that Raylet as dead in /// the client table, which broadcasts the event to all other Raylets. Monitor::Monitor(boost::asio::io_service &io_service, const std::string &redis_address, - int redis_port) - : gcs_client_(redis_address, redis_port), + int redis_port, const std::string &redis_password) + : gcs_client_(redis_address, redis_port, redis_password), num_heartbeats_timeout_(RayConfig::instance().num_heartbeats_timeout()), heartbeat_timer_(io_service) { RAY_CHECK_OK(gcs_client_.Attach(io_service)); diff --git a/src/ray/raylet/monitor.h b/src/ray/raylet/monitor.h index 1786bc3f1..b300bf4cf 100644 --- a/src/ray/raylet/monitor.h +++ b/src/ray/raylet/monitor.h @@ -19,7 +19,7 @@ class Monitor { /// \param redis_address The GCS Redis address to connect to. /// \param redis_port The GCS Redis port to connect to. Monitor(boost::asio::io_service &io_service, const std::string &redis_address, - int redis_port); + int redis_port, const std::string &redis_password); /// Start the monitor. Listen for heartbeats from Raylets and mark Raylets /// that do not send a heartbeat within a given period as dead. diff --git a/src/ray/raylet/monitor_main.cc b/src/ray/raylet/monitor_main.cc index 218faecd4..8cd821752 100644 --- a/src/ray/raylet/monitor_main.cc +++ b/src/ray/raylet/monitor_main.cc @@ -8,14 +8,15 @@ int main(int argc, char *argv[]) { ray::RayLog::ShutDownRayLog, argv[0], RAY_INFO, /*log_dir=*/""); ray::RayLog::InstallFailureSignalHandler(); - RAY_CHECK(argc == 3); + RAY_CHECK(argc == 3 || argc == 4); const std::string redis_address = std::string(argv[1]); int redis_port = std::stoi(argv[2]); + const std::string redis_password = (argc == 4 ? std::string(argv[3]) : ""); // Initialize the monitor. boost::asio::io_service io_service; - ray::raylet::Monitor monitor(io_service, redis_address, redis_port); + ray::raylet::Monitor monitor(io_service, redis_address, redis_port, redis_password); monitor.Start(); io_service.run(); } diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 5fe7f774b..d714b71ac 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -60,7 +60,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_1.store_socket_name = store_sock_1; om_config_1.push_timeout_ms = 10000; server1.reset(new ray::raylet::Raylet( - main_service, "raylet_1", "0.0.0.0", "127.0.0.1", 6379, + main_service, "raylet_1", "0.0.0.0", "127.0.0.1", 6379, "", GetNodeManagerConfig("raylet_1", store_sock_1), om_config_1, gcs_client_1)); // start second server @@ -70,7 +70,7 @@ class TestObjectManagerBase : public ::testing::Test { om_config_2.store_socket_name = store_sock_2; om_config_2.push_timeout_ms = 10000; server2.reset(new ray::raylet::Raylet( - main_service, "raylet_2", "0.0.0.0", "127.0.0.1", 6379, + main_service, "raylet_2", "0.0.0.0", "127.0.0.1", 6379, "", GetNodeManagerConfig("raylet_2", store_sock_2), om_config_2, gcs_client_2)); // connect to stores. diff --git a/src/ray/raylet/raylet.cc b/src/ray/raylet/raylet.cc index df30498f4..11b54b65b 100644 --- a/src/ray/raylet/raylet.cc +++ b/src/ray/raylet/raylet.cc @@ -13,7 +13,8 @@ namespace raylet { Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_name, const std::string &node_ip_address, const std::string &redis_address, - int redis_port, const NodeManagerConfig &node_manager_config, + int redis_port, const std::string &redis_password, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client) : gcs_client_(gcs_client), @@ -33,9 +34,9 @@ Raylet::Raylet(boost::asio::io_service &main_service, const std::string &socket_ DoAcceptObjectManager(); DoAcceptNodeManager(); - RAY_CHECK_OK(RegisterGcs(node_ip_address, socket_name_, - object_manager_config.store_socket_name, redis_address, - redis_port, main_service, node_manager_config)); + RAY_CHECK_OK(RegisterGcs( + node_ip_address, socket_name_, object_manager_config.store_socket_name, + redis_address, redis_port, redis_password, main_service, node_manager_config)); RAY_CHECK_OK(RegisterPeriodicTimer(main_service)); } @@ -52,6 +53,7 @@ ray::Status Raylet::RegisterGcs(const std::string &node_ip_address, const std::string &raylet_socket_name, const std::string &object_store_socket_name, const std::string &redis_address, int redis_port, + const std::string &redis_password, boost::asio::io_service &io_service, const NodeManagerConfig &node_manager_config) { RAY_RETURN_NOT_OK(gcs_client_->Attach(io_service)); diff --git a/src/ray/raylet/raylet.h b/src/ray/raylet/raylet.h index be634616b..9b424781a 100644 --- a/src/ray/raylet/raylet.h +++ b/src/ray/raylet/raylet.h @@ -29,6 +29,7 @@ class Raylet { /// \param node_ip_address The IP address of this node. /// \param redis_address The IP address of the redis instance we are connecting to. /// \param redis_port The port of the redis instance we are connecting to. + /// \param redis_password The password of the redis instance we are connecting to. /// \param node_manager_config Configuration to initialize the node manager. /// scheduler with. /// \param object_manager_config Configuration to initialize the object @@ -36,7 +37,8 @@ class Raylet { /// \param gcs_client A client connection to the GCS. Raylet(boost::asio::io_service &main_service, const std::string &socket_name, const std::string &node_ip_address, const std::string &redis_address, - int redis_port, const NodeManagerConfig &node_manager_config, + int redis_port, const std::string &redis_password, + const NodeManagerConfig &node_manager_config, const ObjectManagerConfig &object_manager_config, std::shared_ptr gcs_client); @@ -49,6 +51,7 @@ class Raylet { const std::string &raylet_socket_name, const std::string &object_store_socket_name, const std::string &redis_address, int redis_port, + const std::string &redis_password, boost::asio::io_service &io_service, const NodeManagerConfig &); ray::Status RegisterPeriodicTimer(boost::asio::io_service &io_service);