diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 2776a14ab..340b427f8 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -689,6 +689,32 @@ def test_lease_request_leak(shutdown_only): assert len(ray.objects()) == 0, ray.objects() +@pytest.mark.parametrize( + "ray_start_cluster", [{ + "num_cpus": 0, + "num_nodes": 1, + "do_init": False, + }], + indirect=True) +def test_ray_address_environment_variable(ray_start_cluster): + address = ray_start_cluster.address + # In this test we use zero CPUs to distinguish between starting a local + # ray cluster and connecting to an existing one. + + # Make sure we connect to an existing cluster if + # RAY_ADDRESS is set. + os.environ["RAY_ADDRESS"] = address + ray.init() + assert "CPU" not in ray.state.cluster_resources() + del os.environ["RAY_ADDRESS"] + ray.shutdown() + + # Make sure we start a new cluster if RAY_ADDRESS is not set. + ray.init() + assert "CPU" in ray.state.cluster_resources() + ray.shutdown() + + if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/joblib/ray_backend.py b/python/ray/util/joblib/ray_backend.py index 2d6618762..f4128b572 100644 --- a/python/ray/util/joblib/ray_backend.py +++ b/python/ray/util/joblib/ray_backend.py @@ -5,8 +5,6 @@ import logging from ray.util.multiprocessing.pool import Pool import ray -RAY_ADDRESS_ENV = "RAY_ADDRESS" - logger = logging.getLogger(__name__) @@ -34,15 +32,13 @@ class RayBackend(MultiprocessingBackend): if n_jobs == -1: if not ray.is_initialized(): import os - if RAY_ADDRESS_ENV in os.environ: - ray_address = os.environ[RAY_ADDRESS_ENV] + if "RAY_ADDRESS" in os.environ: logger.info( "Connecting to ray cluster at address='{}'".format( - ray_address)) - ray.init(address=ray_address) + os.environ["RAY_ADDRESS"])) else: logger.info("Starting local ray cluster") - ray.init() + ray.init() ray_cpus = int(ray.state.cluster_resources()["CPU"]) n_jobs = ray_cpus diff --git a/python/ray/util/multiprocessing/pool.py b/python/ray/util/multiprocessing/pool.py index cf3ce37e2..288f54302 100644 --- a/python/ray/util/multiprocessing/pool.py +++ b/python/ray/util/multiprocessing/pool.py @@ -349,11 +349,12 @@ class Pool: # Else, the priority is: # ray_address argument > RAY_ADDRESS > start new local cluster. if not ray.is_initialized(): - if ray_address is None and RAY_ADDRESS_ENV in os.environ: - ray_address = os.environ[RAY_ADDRESS_ENV] - # Cluster mode. - if ray_address is not None: + if ray_address is None and RAY_ADDRESS_ENV in os.environ: + logger.info("Connecting to ray cluster at address='{}'".format( + os.environ[RAY_ADDRESS_ENV])) + ray.init() + elif ray_address is not None: logger.info("Connecting to ray cluster at address='{}'".format( ray_address)) ray.init(address=ray_address) diff --git a/python/ray/worker.py b/python/ray/worker.py index 4c4ecb6f0..a4a7ecd7c 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -565,6 +565,10 @@ def init(address=None, ray.init(address="123.45.67.89:6379") + You can also define an environment variable called `RAY_ADDRESS` in + the same format as the `address` parameter to connect to an existing + cluster with ray.init(). + Args: address (str): The address of the Ray cluster to connect to. If this address is not provided, then this command will start Redis, @@ -672,6 +676,17 @@ def init(address=None, raise DeprecationWarning("The redis_address argument is deprecated. " "Please use address instead.") + if "RAY_ADDRESS" in os.environ: + if redis_address is None and (address is None or address == "auto"): + address = os.environ["RAY_ADDRESS"] + else: + raise RuntimeError( + "Cannot use both the RAY_ADDRESS environment variable and " + "the address argument of ray.init simultaneously. If you " + "use RAY_ADDRESS to connect to a specific Ray cluster, " + "please call ray.init() or ray.init(address=\"auto\") on the " + "driver.") + if redis_address is not None or address is not None: redis_address, _, _ = services.validate_redis_address( address, redis_address)