diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py
index 4b31edf98..460f598a0 100644
--- a/python/ray/scripts/scripts.py
+++ b/python/ray/scripts/scripts.py
@@ -8,7 +8,6 @@ import json
 import logging
 import os
 import subprocess
-import sys
 
 import ray.services as services
 from ray.autoscaler.commands import (
@@ -74,6 +73,8 @@ def cli(logging_level, logging_format):
     required=False,
     type=str,
     help="the address to use for connecting to Redis")
+@click.option(
+    "--address", required=False, type=str, help="same as --redis-address")
 @click.option(
     "--redis-port",
     required=False,
@@ -216,12 +217,12 @@ def cli(logging_level, logging_format):
     is_flag=True,
     default=False,
     help="Specify whether load code from local file or GCS serialization.")
-def start(node_ip_address, redis_address, redis_port, num_redis_shards,
-          redis_max_clients, redis_password, redis_shard_ports,
-          object_manager_port, node_manager_port, object_store_memory,
-          redis_max_memory, num_cpus, num_gpus, resources, head, include_webui,
-          block, plasma_directory, huge_pages, autoscaling_config,
-          no_redirect_worker_output, no_redirect_output,
+def start(node_ip_address, redis_address, address, redis_port,
+          num_redis_shards, redis_max_clients, redis_password,
+          redis_shard_ports, object_manager_port, node_manager_port,
+          object_store_memory, redis_max_memory, num_cpus, num_gpus, resources,
+          head, include_webui, block, plasma_directory, huge_pages,
+          autoscaling_config, no_redirect_worker_output, no_redirect_output,
           plasma_store_socket_name, raylet_socket_name, temp_dir, include_java,
           java_worker_options, load_code_from_local, internal_config):
     # Convert hostnames to numerical IP address.
@@ -229,6 +230,16 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards,
         node_ip_address = services.address_to_ip(node_ip_address)
+    if address:
+        # --address is an alias of --redis-address; giving both is ambiguous.
+        if redis_address:
+            raise ValueError(
+                "You should specify address instead of redis_address.")
+        if address == "auto":
+            address = services.find_redis_address_or_die()
+        # Assigned before the redis_address conversion below so that a
+        # hostname passed via --address is resolved to an IP as well.
+        redis_address = address
     if redis_address is not None:
         redis_address = services.address_to_ip(redis_address)
 
     try:
         resources = json.loads(resources)
@@ -741,33 +752,7 @@ done
     help="Override the redis address to connect to.")
 def timeline(redis_address):
     if not redis_address:
-        import psutil
-        pids = psutil.pids()
-        redis_addresses = set()
-        for pid in pids:
-            try:
-                proc = psutil.Process(pid)
-                for arglist in proc.cmdline():
-                    for arg in arglist.split(" "):
-                        if arg.startswith("--redis-address="):
-                            addr = arg.split("=")[1]
-                            redis_addresses.add(addr)
-            except psutil.AccessDenied:
-                pass
-            except psutil.NoSuchProcess:
-                pass
-        if len(redis_addresses) > 1:
-            logger.info(
-                "Found multiple active Ray instances: {}. ".format(
-                    redis_addresses) +
-                "Please specify the one to connect to with --redis-address.")
-            sys.exit(1)
-        elif not redis_addresses:
-            logger.info(
-                "Could not find any running Ray instance. "
-                "Please specify the one to connect to with --redis-address.")
-            sys.exit(1)
-        redis_address = redis_addresses.pop()
+        redis_address = services.find_redis_address_or_die()
     logger.info("Connecting to Ray instance at {}.".format(redis_address))
     ray.init(redis_address=redis_address)
     time = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
diff --git a/python/ray/services.py b/python/ray/services.py
index c194da18f..e21b66fd4 100644
--- a/python/ray/services.py
+++ b/python/ray/services.py
@@ -93,6 +93,51 @@ def include_java_from_redis(redis_client):
     return redis_client.get("INCLUDE_JAVA") == b"1"
 
 
+def find_redis_address_or_die():
+    """Find the Redis address of the Ray instance running on this machine.
+
+    Scans the command line of every visible process for a
+    ``--redis-address=`` argument and collects the distinct values.
+
+    Returns:
+        The Redis address, if exactly one distinct address was found.
+
+    Raises:
+        ImportError: If `psutil` is not installed.
+        ConnectionError: If no address, or more than one distinct address,
+            was found.
+    """
+    try:
+        import psutil
+    except ImportError:
+        raise ImportError(
+            "Please install `psutil` to automatically detect the Ray cluster.")
+    pids = psutil.pids()
+    redis_addresses = set()
+    for pid in pids:
+        # Processes that exit mid-scan or that we may not inspect are skipped.
+        try:
+            proc = psutil.Process(pid)
+            for arglist in proc.cmdline():
+                for arg in arglist.split(" "):
+                    if arg.startswith("--redis-address="):
+                        addr = arg.split("=")[1]
+                        redis_addresses.add(addr)
+        except psutil.AccessDenied:
+            pass
+        except psutil.NoSuchProcess:
+            pass
+    if len(redis_addresses) > 1:
+        raise ConnectionError(
+            "Found multiple active Ray instances: {}. ".format(redis_addresses) +
+            "Please specify the one to connect to by setting `address`.")
+    elif not redis_addresses:
+        raise ConnectionError(
+            "Could not find any running Ray instance. "
+            "Please specify the one to connect to by setting `address`.")
+    return redis_addresses.pop()
+
+
 def get_address_info_from_redis_helper(redis_address,
                                        node_ip_address,
                                        redis_password=None):
diff --git a/python/ray/worker.py b/python/ray/worker.py
index abd42a908..19381c313 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -1262,6 +1262,7 @@ def _initialize_serialization(job_id, worker=global_worker):
 
 
 def init(redis_address=None,
+         address=None,
          num_cpus=None,
          num_gpus=None,
          resources=None,
@@ -1313,6 +1314,8 @@ def init(redis_address=None,
             this address is not provided, then this command will start Redis, a
             raylet, a plasma store, a plasma manager, and some workers.
             It will also kill these processes when Python exits.
+        address (str): Same as redis_address. The special value "auto"
+            connects to the Ray instance found running on this machine.
         num_cpus (int): Number of cpus the user wishes all raylets to
             be configured with.
         num_gpus (int): Number of gpus the user wishes all raylets to
@@ -1376,6 +1379,15 @@ def init(redis_address=None,
             arguments is passed in.
     """
 
+    if address:
+        # `address` is an alias of `redis_address`; giving both is ambiguous.
+        if redis_address:
+            raise ValueError(
+                "You should specify address instead of redis_address.")
+        if address == "auto":
+            address = services.find_redis_address_or_die()
+        redis_address = address
+
     if configure_logging:
         setup_logger(logging_level, logging_format)
 
diff --git a/rllib/train.py b/rllib/train.py
index af4d62b0c..16096d3ec 100755
--- a/rllib/train.py
+++ b/rllib/train.py
@@ -37,7 +37,7 @@ def create_parser(parser_creator=None):
 
     # See also the base parser definition in ray/tune/config_parser.py
     parser.add_argument(
-        "--redis-address",
+        "--ray-address",
         default=None,
         type=str,
         help="Connect to an existing Ray cluster at this address instead "
@@ -144,10 +144,10 @@ def run(args, parser):
                 num_gpus=args.ray_num_gpus or 0,
                 object_store_memory=args.ray_object_store_memory,
                 redis_max_memory=args.ray_redis_max_memory)
-        ray.init(redis_address=cluster.redis_address)
+        ray.init(address=cluster.redis_address)
     else:
         ray.init(
-            redis_address=args.redis_address,
+            address=args.ray_address,
             object_store_memory=args.ray_object_store_memory,
             redis_max_memory=args.ray_redis_max_memory,
             num_cpus=args.ray_num_cpus,