diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 62aba7a45..e8e32fbb3 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -69,7 +69,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { manager = new RunManager(rayConfig); manager.startRayProcesses(true); } - redisClient = new RedisClient(rayConfig.getRedisAddress()); + redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 7a22b07b2..bf5582be2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -47,6 +47,8 @@ public class RayConfig { private Integer redisPort; public final int headRedisPort; public final int numberRedisShards; + public final String headRedisPassword; + public final String redisPassword; public final String objectStoreSocketName; public final Long objectStoreSize; @@ -157,6 +159,8 @@ public class RayConfig { } headRedisPort = config.getInt("ray.redis.head-port"); numberRedisShards = config.getInt("ray.redis.shard-number"); + headRedisPassword = config.getString("ray.redis.head-password"); + redisPassword = config.getString("ray.redis.password"); // object store configurations objectStoreSocketName = config.getString("ray.object-store.socket-name"); diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java index 4997bb55a..4aa2c9607 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java @@ -1,6 +1,8 @@ package org.ray.runtime.gcs; import java.util.Map; + +import org.ray.runtime.util.StringUtil; import redis.clients.jedis.Jedis; import redis.clients.jedis.JedisPool; import redis.clients.jedis.JedisPoolConfig; @@ -15,6 +17,10 @@ public class RedisClient { private JedisPool jedisPool; public RedisClient(String redisAddress) { + this(redisAddress, null); + } + + public RedisClient(String redisAddress, String password) { String[] ipAndPort = redisAddress.split(":"); if (ipAndPort.length != 2) { throw new IllegalArgumentException("The argument redisAddress " + @@ -23,8 +29,14 @@ public class RedisClient { JedisPoolConfig jedisPoolConfig = new JedisPoolConfig(); jedisPoolConfig.setMaxTotal(JEDIS_POOL_SIZE); - jedisPool = new JedisPool(jedisPoolConfig, ipAndPort[0], - Integer.parseInt(ipAndPort[1]), 30000); + + if (StringUtil.isNullOrEmpty(password)) { + jedisPool = new JedisPool(jedisPoolConfig, + ipAndPort[0], Integer.parseInt(ipAndPort[1]), 30000); + } else { + jedisPool = new JedisPool(jedisPoolConfig, ipAndPort[0], + Integer.parseInt(ipAndPort[1]), 30000, password); + } } public Long set(final String key, final String value, final String field) { diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index d2e72c646..347ec3388 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -2,6 +2,7 @@ package org.ray.runtime.runner; import com.google.common.base.Joiner; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.io.File; import java.io.IOException; import java.time.LocalDateTime; @@ -16,6 +17,7 @@ import java.util.stream.Stream; import org.ray.runtime.config.RayConfig; import org.ray.runtime.util.FileUtil; import org.ray.runtime.util.ResourceUtil; +import org.ray.runtime.util.StringUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import redis.clients.jedis.Jedis; @@ -146,9 +148,13 @@ public class RunManager { private void startRedisServer() { // start primary redis - String primary = startRedisInstance(rayConfig.nodeIp, rayConfig.headRedisPort, null); + String primary = startRedisInstance(rayConfig.nodeIp, + rayConfig.headRedisPort, rayConfig.headRedisPassword, null); rayConfig.setRedisAddress(primary); try (Jedis client = new Jedis("127.0.0.1", rayConfig.headRedisPort)) { + if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) { + client.auth(rayConfig.headRedisPassword); + } client.set("UseRaylet", "1"); // Register the number of Redis shards in the primary shard, so that clients // know how many redis shards to expect under RedisShards. @@ -156,14 +162,15 @@ public class RunManager { // start redis shards for (int i = 0; i < rayConfig.numberRedisShards; i++) { - String shard = startRedisInstance(rayConfig.nodeIp, rayConfig.headRedisPort + i + 1, i); + String shard = startRedisInstance(rayConfig.nodeIp, + rayConfig.headRedisPort + i + 1, rayConfig.headRedisPassword, i); client.rpush("RedisShards", shard); } } } - private String startRedisInstance(String ip, int port, Integer shard) { - List command = ImmutableList.of( + private String startRedisInstance(String ip, int port, String password, Integer shard) { + List command = Lists.newArrayList( rayConfig.redisServerExecutablePath, "--protected-mode", "no", @@ -174,10 +181,20 @@ public class RunManager { "--loadmodule", rayConfig.redisModulePath ); + + if (!StringUtil.isNullOrEmpty(password)) { + command.add("--requirepass "); + command.add(password); + } + String name = shard == null ? "redis" : "redis-" + shard; startProcess(command, null, name); try (Jedis client = new Jedis("127.0.0.1", port)) { + if (!StringUtil.isNullOrEmpty(password)) { + client.auth(password); + } + // Configure Redis to only generate notifications for the export keys. client.configSet("notify-keyspace-events", "Kl"); // Put a time stamp in Redis to indicate when it was started. @@ -192,6 +209,11 @@ public class RunManager { int maximumStartupConcurrency = Math.max(1, Math.min(rayConfig.resources.getOrDefault("CPU", 0.0).intValue(), hardwareConcurrency)); + String redisPasswordOption = ""; + if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) { + redisPasswordOption = rayConfig.headRedisPassword; + } + // See `src/ray/raylet/main.cc` for the meaning of each parameter. List command = ImmutableList.of( rayConfig.rayletExecutablePath, @@ -207,7 +229,8 @@ public class RunManager { ResourceUtil.getResourcesStringFromMap(rayConfig.resources), String.join(",", rayConfig.rayletConfigParameters), // The internal config list. buildPythonWorkerCommand(), // python worker command - buildWorkerCommandRaylet() // java worker command + buildWorkerCommandRaylet(), // java worker command + redisPasswordOption ); startProcess(command, null, "raylet"); @@ -248,6 +271,11 @@ public class RunManager { // Config overwrite cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress()); + // redis password + if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) { + cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword); + } + cmd.addAll(rayConfig.jvmParameters); // Main class diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index f863300ae..6a3f95a01 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -64,6 +64,10 @@ ray { address: "" // If `redis.server` isn't provided, which port we should use to start redis server. head-port: 6379 + // The password used to start the redis server on the head node. + head-password: "" + // The password used to connect to the redis server. + password:"" // If `redis.server` isn't provided, how many Redis shards we should start in addition to the // primary Redis shard. The ports of these shards will be `head-port + 1`, `head-port + 2`, etc. shard-number: 1 diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java index d44ae348e..22c35670a 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java @@ -11,6 +11,7 @@ public class BaseTest { public void setUp() { System.setProperty("ray.home", "../.."); System.setProperty("ray.resources", "CPU:4,RES-A:4"); + beforeInitRay(); Ray.init(); } @@ -20,6 +21,7 @@ public class BaseTest { // We could not enable this until `systemInfo` enabled. //File rayletSocketFIle = new File(Ray.systemInfo().rayletSocketName()); Ray.shutdown(); + afterShutdownRay(); //remove raylet socket file //rayletSocketFIle.delete(); @@ -29,4 +31,11 @@ public class BaseTest { System.clearProperty("ray.resources"); } + protected void beforeInitRay() { + + } + + protected void afterShutdownRay() { + + } } diff --git a/java/test/src/main/java/org/ray/api/test/RedisPasswordTest.java b/java/test/src/main/java/org/ray/api/test/RedisPasswordTest.java new file mode 100644 index 000000000..210a4a045 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/RedisPasswordTest.java @@ -0,0 +1,34 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayObject; +import org.ray.api.annotation.RayRemote; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RedisPasswordTest extends BaseTest { + + @Override + public void beforeInitRay() { + System.setProperty("ray.redis.head-password", "12345678"); + System.setProperty("ray.redis.password", "12345678"); + } + + @Override + public void afterShutdownRay() { + System.clearProperty("ray.redis.head-password"); + System.clearProperty("ray.redis.password"); + } + + @RayRemote + public static String echo(String str) { + return str; + } + + @Test + public void testRedisPassword() { + RayObject obj = Ray.call(RedisPasswordTest::echo, "hello"); + Assert.assertEquals("hello", obj.get()); + } + +} diff --git a/python/ray/services.py b/python/ray/services.py index 928105d3a..7487bd5b6 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1029,7 +1029,12 @@ def start_raylet(redis_address, java_worker_options = (java_worker_options or DEFAULT_JAVA_WORKER_OPTIONS) java_worker_command = build_java_worker_command( - java_worker_options, redis_address, plasma_store_name, raylet_name) + java_worker_options, + redis_address, + plasma_store_name, + raylet_name, + redis_password, + ) else: java_worker_command = "" @@ -1086,8 +1091,13 @@ def start_raylet(redis_address, return process_info -def build_java_worker_command(java_worker_options, redis_address, - plasma_store_name, raylet_name): +def build_java_worker_command( + java_worker_options, + redis_address, + plasma_store_name, + raylet_name, + redis_password, +): """This method assembles the command used to start a Java worker. Args: @@ -1096,7 +1106,7 @@ def build_java_worker_command(java_worker_options, redis_address, plasma_store_name (str): The name of the plasma store socket to connect to. raylet_name (str): The name of the raylet socket to create. - + redis_password (str): The password of connect to redis. Returns: The command string for starting Java worker. """ @@ -1113,6 +1123,9 @@ def build_java_worker_command(java_worker_options, redis_address, if raylet_name is not None: command += "-Dray.raylet.socket-name={} ".format(raylet_name) + if redis_password is not None: + command += ("-Dray.redis-password=%s", redis_password) + command += "-Dray.home={} ".format(RAY_HOME) command += "-Dray.log-dir={} ".format(get_logs_dir_path()) command += "org.ray.runtime.runner.worker.DefaultWorker"