diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d6093b612..2d1603f88 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -81,6 +81,12 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, # temporary fix. We should actually redirect stdout and stderr to Redis in # some way. + # Convert hostnames to numerical IP address. + if node_ip_address is not None: + node_ip_address = services.address_to_ip(node_ip_address) + if redis_address is not None: + redis_address = services.address_to_ip(redis_address) + if head: # Start Ray on the head node. if redis_address is not None: diff --git a/python/ray/services.py b/python/ray/services.py index cf845a98b..424b1009c 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -165,6 +165,25 @@ def all_processes_alive(exclude=[]): return True +def address_to_ip(address): + """Convert a hostname to a numerical IP addresses in an address. + + This should be a no-op if address already contains an actual numerical IP + address. + + Args: + address: This can be either a string containing a hostname (or an IP + address) and a port or it can be just an IP address. + + Returns: + The same address but with the hostname replaced by a numerical IP + address. + """ + address_parts = address.split(":") + ip_address = socket.gethostbyname(address_parts[0]) + return ":".join([ip_address] + address_parts[1:]) + + def get_node_ip_address(address="8.8.8.8:53"): """Determine the IP address of the local node. diff --git a/python/ray/worker.py b/python/ray/worker.py index f5ed90e1d..9938a247a 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1359,6 +1359,12 @@ def init(redis_address=None, node_ip_address=None, object_id_seed=None, Exception: An exception is raised if an inappropriate combination of arguments is passed in. """ + # Convert hostnames to numerical IP address. + if node_ip_address is not None: + node_ip_address = services.address_to_ip(node_ip_address) + if redis_address is not None: + redis_address = services.address_to_ip(redis_address) + info = {"node_ip_address": node_ip_address, "redis_address": redis_address} return _init(address_info=info, start_ray_local=(redis_address is None), diff --git a/test/multi_node_test.py b/test/multi_node_test.py index 952e08da0..fc84ffd71 100644 --- a/test/multi_node_test.py +++ b/test/multi_node_test.py @@ -208,6 +208,24 @@ class StartRayScriptTest(unittest.TestCase): "--redis-address", "127.0.0.1:6379"]) subprocess.Popen(["ray", "stop"]).wait() + def testUsingHostnames(self): + # Start the Ray processes on this machine. + subprocess.check_output( + ["ray", "start", "--head", + "--node-ip-address=localhost", + "--redis-port=6379"]).decode("ascii") + + ray.init(node_ip_address="localhost", redis_address="localhost:6379") + + @ray.remote + def f(): + return 1 + + self.assertEqual(ray.get(f.remote()), 1) + + # Kill the Ray cluster. + subprocess.Popen(["ray", "stop"]).wait() + if __name__ == "__main__": unittest.main(verbosity=2)