diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 5f25b29e0..b17192a6e 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -236,15 +236,11 @@ def start(node_ip_address, redis_address, address, redis_port, # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) - if redis_address is not None: - redis_address = services.address_to_ip(redis_address) - if address: - if redis_address: - raise ValueError( - "You should specify address instead of redis_address.") - if address == "auto": - address = services.find_redis_address_or_die() - redis_address = address + + if redis_address is not None or address is not None: + (redis_address, redis_address_ip, + redis_address_port) = services.validate_redis_address( + address, redis_address) try: resources = json.loads(resources) @@ -339,10 +335,10 @@ def start(node_ip_address, redis_address, address, redis_port, # Start Ray on a non-head node. if redis_port is not None: raise Exception("If --head is not passed in, --redis-port is not " - "allowed") + "allowed.") if redis_shard_ports is not None: raise Exception("If --head is not passed in, --redis-shard-ports " - "is not allowed") + "is not allowed.") if redis_address is None: raise Exception("If --head is not passed in, --redis-address must " "be provided.") @@ -359,12 +355,10 @@ def start(node_ip_address, redis_address, address, redis_port, raise ValueError("--include-java should only be set for the head " "node.") - redis_ip_address, redis_port = redis_address.split(":") - # Wait for the Redis server to be started. And throw an exception if we # can't connect to it. services.wait_for_redis_to_start( - redis_ip_address, int(redis_port), password=redis_password) + redis_address_ip, redis_address_port, password=redis_password) # Create a Redis client. redis_client = services.create_redis_client( diff --git a/python/ray/services.py b/python/ray/services.py index 598b7f360..69ee0dc70 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -203,6 +203,51 @@ def remaining_processes_alive(): return ray.worker._global_node.remaining_processes_alive() +def validate_redis_address(address, redis_address): + """Validates redis address parameter and splits it into host/ip components. + + We temporarily support both 'address' and 'redis_address', so both are + handled here. + + Returns: + redis_address: string containing the full address. + redis_ip: string representing the host portion of the address. + redis_port: integer representing the port portion of the address. + + Raises: + ValueError: if both address and redis_address were specified or the + address was malformed. + """ + + if redis_address == "auto": + raise ValueError("auto address resolution not supported for " + "redis_address parameter. Please use address.") + + if address: + if redis_address: + raise ValueError( + "Both address and redis_address specified. Use only address.") + if address == "auto": + address = find_redis_address_or_die() + redis_address = address + + redis_address = address_to_ip(redis_address) + + redis_address_parts = redis_address.split(":") + if len(redis_address_parts) != 2: + raise ValueError("Malformed address. Expected ':'.") + redis_ip = redis_address_parts[0] + try: + redis_port = int(redis_address_parts[1]) + except ValueError: + raise ValueError("Malformed address port. Must be an integer.") + if redis_port < 1024 or redis_port > 65535: + raise ValueError("Invalid address port. Must " + "be between 1024 and 65535.") + + return redis_address, redis_ip, redis_port + + def address_to_ip(address): """Convert a hostname to a numerical IP addresses in an address. diff --git a/python/ray/worker.py b/python/ray/worker.py index 05b2c3862..6047d0d7b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1405,13 +1405,9 @@ def init(address=None, arguments is passed in. """ - if address: - if redis_address: - raise ValueError( - "You should specify address instead of redis_address.") - if address == "auto": - address = services.find_redis_address_or_die() - redis_address = address + if redis_address is not None or address is not None: + redis_address, _, _ = services.validate_redis_address( + address, redis_address) if configure_logging: setup_logger(logging_level, logging_format) @@ -1441,8 +1437,6 @@ def init(address=None, # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) - if redis_address is not None: - redis_address = services.address_to_ip(redis_address) global _global_node if driver_mode == LOCAL_MODE: