[Java] Simplify Ray.init() by invoking ray start internally (#10762)

This commit is contained in:
Kai Yang
2020-12-04 14:33:45 +08:00
committed by GitHub
parent 8cebe1e79c
commit 21fcee28f9
39 changed files with 367 additions and 1085 deletions
+1 -22
View File
@@ -70,6 +70,7 @@ define_java_module(
visibility = ["//visibility:public"],
deps = [
":io_ray_ray_api",
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -134,27 +135,12 @@ filegroup(
],
)
native_java_binary("runtime", "raylet", "//:raylet")
native_java_binary("runtime", "plasma_store_server", "//:plasma_store_server")
native_java_binary("runtime", "redis-server", "//:redis-server")
native_java_binary("runtime", "gcs_server", "//:gcs_server")
native_java_binary("runtime", "libray_redis_module.so", "//:libray_redis_module.so")
native_java_library("runtime", "core_worker_library_java", "//:libcore_worker_library_java.so")
filegroup(
name = "java_native_deps",
srcs = [
":core_worker_library_java",
":gcs_server",
":libray_redis_module.so",
":plasma_store_server",
":raylet",
":redis-server",
],
)
@@ -252,13 +238,6 @@ genrule(
WORK_DIR="$$(pwd)"
rm -rf "$$WORK_DIR/python/ray/jars" && mkdir -p "$$WORK_DIR/python/ray/jars"
cp -f $(location //java:ray_dist_deploy.jar) "$$WORK_DIR/python/ray/jars/ray_dist.jar"
chmod +w "$$WORK_DIR/python/ray/jars/ray_dist.jar"
zip -d "$$WORK_DIR/python/ray/jars/ray_dist.jar" \
"native/*/gcs_server" \
"native/*/libray_redis_module.so" \
"native/*/plasma_store_server" \
"native/*/raylet" \
"native/*/redis-server"
date > $@
""",
local = 1,
@@ -28,16 +28,6 @@ public interface RuntimeContext {
*/
boolean wasCurrentActorRestarted();
/**
* Get the raylet socket name.
*/
String getRayletSocketName();
/**
* Get the object store socket name.
*/
String getObjectStoreSocketName();
/**
* Return true if Ray is running in single-process mode, false if Ray is running in cluster mode.
*/
+5
View File
@@ -39,6 +39,11 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@@ -16,7 +16,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
@Override
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.getInstance();
RayConfig rayConfig = RayConfig.create();
LoggingUtil.setupLogging(rayConfig);
Logger logger = LoggerFactory.getLogger(DefaultRayRuntimeFactory.class);
@@ -57,7 +57,6 @@ public class RayDevRuntime extends AbstractRayRuntime {
taskSubmitter = null;
}
taskExecutor = null;
RayConfig.reset();
}
@Override
@@ -5,10 +5,8 @@ import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.UniqueId;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.NativeWorkerContext;
import io.ray.runtime.exception.RayException;
import io.ray.runtime.exception.RayIntentionalSystemExitException;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.gcs.GcsClientOptions;
@@ -22,15 +20,12 @@ import io.ray.runtime.task.NativeTaskSubmitter;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.JniUtils;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,7 +36,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
private RunManager manager = null;
private boolean startRayHead = false;
/**
* In Java, GC runs in a standalone thread, and we can't control the exact
@@ -52,124 +47,101 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
*/
private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock();
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
static {
LOGGER.debug("Loading native libraries.");
// Expose ray ABI symbols which may be depended by other shared
// libraries such as libstreaming_java.so.
// See BUILD.bazel:libcore_worker_library_java.so
final RayConfig rayConfig = RayConfig.getInstance();
if (rayConfig.getRedisAddress() != null && rayConfig.workerMode == WorkerType.DRIVER) {
// Fetch session dir from GCS if this is a driver that is connecting to the existing GCS.
private void updateSessionDir() {
if (rayConfig.workerMode == WorkerType.DRIVER) {
// Fetch session dir from GCS if this is a driver.
RedisClient client = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
final String sessionDir = client.get("session_dir", null);
Preconditions.checkNotNull(sessionDir);
rayConfig.setSessionDir(sessionDir);
}
JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
LOGGER.debug("Native libraries loaded.");
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
}
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
loadConfigFromGcs(rayConfig);
}
private static void loadConfigFromGcs(RayConfig rayConfig) {
if (rayConfig.getRedisAddress() != null) {
GcsClient tempGcsClient =
new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
for (Map.Entry<String, String> entry :
tempGcsClient.getInternalConfig().entrySet()) {
rayConfig.rayletConfigParameters.put(entry.getKey(), entry.getValue());
}
if (rayConfig.workerMode == WorkerType.DRIVER) {
// Keep this method logic in sync with `services.get_address_info_from_redis_helper`
int numRetries = 5;
int retryCount = 0;
boolean configLoaded = false;
while (retryCount++ < numRetries) {
for (NodeInfo nodeInfo : tempGcsClient.getAllNodeInfo()) {
if (rayConfig.nodeIp.equals(nodeInfo.nodeAddress) ||
(nodeInfo.nodeAddress.equals("127.0.0.1") &&
rayConfig.nodeIp.equals(rayConfig.getRedisAddress()))) {
rayConfig.objectStoreSocketName = nodeInfo.objectStoreSocketName;
rayConfig.rayletSocketName = nodeInfo.rayletSocketName;
rayConfig.nodeManagerPort = nodeInfo.nodeManagerPort;
configLoaded = true;
break;
}
}
if (!configLoaded) {
LOGGER.warn("Some processes that the driver needs to connect to have " +
"not registered with Redis, so retrying. Have you run " +
"'ray start' on this node?");
try {
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
break;
}
}
if (!configLoaded) {
throw new RayException("Some processes that the driver needs to connect to have " +
"not registered with Redis. Have you run 'ray start' on this node?");
}
}
private void loadConfigFromGcs() {
rayConfig.rayletConfigParameters.clear();
for (Map.Entry<String, String> entry : gcsClient.getInternalConfig().entrySet()) {
rayConfig.rayletConfigParameters.put(entry.getKey(), entry.getValue());
}
}
@Override
public void start() {
if (rayConfig.getRedisAddress() == null) {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
try {
if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getRedisAddress() == null) {
// Set it to true before `RunManager.startRayHead` so `Ray.shutdown()` can still kill
// Ray processes even if `Ray.init()` failed.
startRayHead = true;
RunManager.startRayHead(rayConfig);
}
Preconditions.checkNotNull(rayConfig.getRedisAddress());
updateSessionDir();
// Expose ray ABI symbols which may be depended by other shared
// libraries such as libstreaming_java.so.
// See BUILD.bazel:libcore_worker_library_java.so
Preconditions.checkNotNull(rayConfig.sessionDir);
JniUtils.loadLibrary(rayConfig.sessionDir, BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
if (rayConfig.workerMode == WorkerType.DRIVER) {
RunManager.getAddressInfoAndFillConfig(rayConfig);
}
gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
loadConfigFromGcs();
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.putAllWorkerEnv(rayConfig.workerEnv)
.addAllCodeSearchPath(rayConfig.codeSearchPath);
serializedJobConfig = jobConfigBuilder.build().toByteArray();
}
Map<String, String> rayletConfigStringMap = new HashMap<>();
for (Map.Entry<String, Object> entry : rayConfig.rayletConfigParameters.entrySet()) {
rayletConfigStringMap.put(entry.getKey(), entry.getValue().toString());
}
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayletConfigStringMap, serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext, shutdownLock);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
} catch (Exception e) {
if (startRayHead) {
try {
RunManager.stopRay();
} catch (Exception e2) {
// Ignore
}
}
throw e;
}
gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.putAllWorkerEnv(rayConfig.workerEnv)
.addAllCodeSearchPath(rayConfig.codeSearchPath);
serializedJobConfig = jobConfigBuilder.build().toByteArray();
}
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayConfig.rayletConfigParameters, serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext, shutdownLock);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
}
@Override
@@ -183,27 +155,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
try {
if (rayConfig.workerMode == WorkerType.DRIVER) {
nativeShutdown();
if (null != manager) {
manager.cleanup();
manager = null;
if (startRayHead) {
startRayHead = false;
RunManager.stopRay();
}
}
if (null != gcsClient) {
gcsClient.destroy();
gcsClient = null;
}
RayConfig.reset();
LOGGER.debug("RayNativeRuntime shutdown");
} finally {
writeLock.unlock();
}
}
// For test purpose only
public RunManager getRunManager() {
return manager;
}
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
@@ -2,7 +2,6 @@ package io.ray.runtime.config;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigException;
@@ -12,17 +11,15 @@ import com.typesafe.config.ConfigValue;
import io.ray.api.id.JobId;
import io.ray.runtime.generated.Common.WorkerType;
import io.ray.runtime.util.NetworkUtil;
import io.ray.runtime.util.ResourceUtil;
import java.io.File;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
/**
* Configurations of Ray runtime.
@@ -33,13 +30,6 @@ public class RayConfig {
public static final String DEFAULT_CONFIG_FILE = "ray.default.conf";
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
private static final Random RANDOM = new Random();
private static final DateTimeFormatter DATE_TIME_FORMATTER =
DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss");
private static final String DEFAULT_TEMP_DIR = "/tmp/ray";
private Config config;
/**
@@ -48,54 +38,25 @@ public class RayConfig {
public final String nodeIp;
public final WorkerType workerMode;
public final RunMode runMode;
public final Map<String, Double> resources;
private JobId jobId;
public String sessionDir;
public String logDir;
public final List<String> libraryPath;
public final List<String> classpath;
public final List<String> jvmParameters;
private String redisAddress;
private String redisIp;
private Integer redisPort;
public final int headRedisPort;
public final int[] redisShardPorts;
public final int numberRedisShards;
public final String headRedisPassword;
public final String redisPassword;
// RPC socket name of object store.
public String objectStoreSocketName;
public final Long objectStoreSize;
// RPC socket name of Raylet.
public String rayletSocketName;
// Listening port for node manager.
public int nodeManagerPort;
public final Map<String, String> rayletConfigParameters;
public final Map<String, Object> rayletConfigParameters;
public List<String> codeSearchPath;
public final String pythonWorkerCommand;
public final List<String> codeSearchPath;
private static volatile RayConfig instance = null;
public static RayConfig getInstance() {
if (instance == null) {
synchronized (RayConfig.class) {
if (instance == null) {
instance = RayConfig.create();
}
}
}
return instance;
}
public static void reset() {
synchronized (RayConfig.class) {
instance = null;
}
}
public final List<String> headArgs;
public final int numWorkersPerProcess;
@@ -140,15 +101,6 @@ public class RayConfig {
} else {
nodeIp = NetworkUtil.getIpAddress(null);
}
// Resources.
resources = ResourceUtil.getResourcesMapFromString(
config.getString("ray.resources"));
if (isDriver) {
if (!resources.containsKey("CPU")) {
int numCpu = Runtime.getRuntime().availableProcessors();
resources.put("CPU", numCpu * 1.0);
}
}
// Job id.
String jobId = config.getString("ray.job.id");
if (!jobId.isEmpty()) {
@@ -168,25 +120,16 @@ public class RayConfig {
}
}
workerEnv = workerEnvBuilder.build();
updateSessionDir();
// Object store configurations.
objectStoreSize = config.getBytes("ray.object-store.size");
updateSessionDir(null);
// Library path.
libraryPath = config.getStringList("ray.library.path");
// Custom classpath.
classpath = config.getStringList("ray.classpath");
// Custom worker jvm parameters.
if (config.hasPath("ray.worker.jvm-parameters")) {
jvmParameters = config.getStringList("ray.worker.jvm-parameters");
} else {
jvmParameters = ImmutableList.of();
// Object store socket name.
if (config.hasPath("ray.object-store.socket-name")) {
objectStoreSocketName = config.getString("ray.object-store.socket-name");
}
if (config.hasPath("ray.worker.python-command")) {
pythonWorkerCommand = config.getString("ray.worker.python-command");
} else {
pythonWorkerCommand = null;
// Raylet socket name.
if (config.hasPath("ray.raylet.socket-name")) {
rayletSocketName = config.getString("ray.raylet.socket-name");
}
// Redis configurations.
@@ -198,17 +141,6 @@ public class RayConfig {
this.redisAddress = null;
}
if (config.hasPath("ray.redis.head-port")) {
headRedisPort = config.getInt("ray.redis.head-port");
} else {
headRedisPort = NetworkUtil.getUnusedPort();
}
numberRedisShards = config.getInt("ray.redis.shard-number");
redisShardPorts = new int[numberRedisShards];
for (int i = 0; i < numberRedisShards; i++) {
redisShardPorts[i] = NetworkUtil.getUnusedPort();
}
headRedisPassword = config.getString("ray.redis.head-password");
redisPassword = config.getString("ray.redis.password");
// Raylet node manager port.
if (config.hasPath("ray.raylet.node-manager-port")) {
@@ -216,7 +148,6 @@ public class RayConfig {
} else {
Preconditions.checkState(workerMode != WorkerType.WORKER,
"Worker started by raylet should accept the node manager port from raylet.");
nodeManagerPort = NetworkUtil.getUnusedPort();
}
// Raylet parameters.
@@ -224,13 +155,27 @@ public class RayConfig {
Config rayletConfig = config.getConfig("ray.raylet.config");
for (Map.Entry<String, ConfigValue> entry : rayletConfig.entrySet()) {
Object value = entry.getValue().unwrapped();
rayletConfigParameters.put(entry.getKey(), value == null ? "" : value.toString());
if (value != null) {
if (value instanceof String) {
String valueString = (String) value;
Boolean booleanValue = BooleanUtils.toBooleanObject(valueString);
if (booleanValue != null) {
value = booleanValue;
} else if (NumberUtils.isParsable(valueString)) {
value = NumberUtils.createNumber(valueString);
}
}
rayletConfigParameters.put(entry.getKey(), value);
}
}
// Job code search path.
String codeSearchPathString = null;
if (config.hasPath("ray.job.code-search-path")) {
codeSearchPath = Arrays.asList(
config.getString("ray.job.code-search-path").split(":"));
codeSearchPathString = config.getString("ray.job.code-search-path");
}
if (!StringUtils.isEmpty(codeSearchPathString)) {
codeSearchPath = Arrays.asList(codeSearchPathString.split(":"));
} else {
codeSearchPath = Collections.emptyList();
}
@@ -258,6 +203,8 @@ public class RayConfig {
numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
}
headArgs = config.getStringList("ray.head-args");
// Validate config.
validate();
}
@@ -267,24 +214,12 @@ public class RayConfig {
Preconditions.checkState(this.redisAddress == null, "Redis address was already set");
this.redisAddress = redisAddress;
String[] ipAndPort = redisAddress.split(":");
Preconditions.checkArgument(ipAndPort.length == 2, "Invalid redis address.");
this.redisIp = ipAndPort[0];
this.redisPort = Integer.parseInt(ipAndPort[1]);
}
public String getRedisAddress() {
return redisAddress;
}
public String getRedisIp() {
return redisIp;
}
public Integer getRedisPort() {
return redisPort;
}
public void setJobId(JobId jobId) {
this.jobId = jobId;
}
@@ -298,11 +233,7 @@ public class RayConfig {
}
public void setSessionDir(String sessionDir) {
this.sessionDir = sessionDir;
}
public String getSessionDir() {
return sessionDir;
updateSessionDir(sessionDir);
}
public Config getInternalConfig() {
@@ -312,7 +243,8 @@ public class RayConfig {
/**
* Renders the config value as a HOCON string.
*/
public String render() {
@Override
public String toString() {
// These items might be dynamically generated or mutated at runtime.
// Explicitly include them.
Map<String, Object> dynamic = new HashMap<>();
@@ -321,24 +253,19 @@ public class RayConfig {
dynamic.put("ray.object-store.socket-name", objectStoreSocketName);
dynamic.put("ray.raylet.node-manager-port", nodeManagerPort);
dynamic.put("ray.address", redisAddress);
dynamic.put("ray.job.code-search-path", codeSearchPath);
Config toRender = ConfigFactory.parseMap(dynamic).withFallback(config);
return toRender.root().render(ConfigRenderOptions.concise());
}
private void updateSessionDir() {
private void updateSessionDir(String sessionDir) {
// session dir
if (workerMode == WorkerType.DRIVER) {
final int minBound = 100000;
final int maxBound = 999999;
final String sessionName = String.format("session_%s_%d", DATE_TIME_FORMATTER.format(
LocalDateTime.now()), RANDOM.nextInt(maxBound - minBound) + minBound);
sessionDir = String.format("%s/%s", DEFAULT_TEMP_DIR, sessionName);
} else if (workerMode == WorkerType.WORKER) {
sessionDir = removeTrailingSlash(config.getString("ray.session-dir"));
} else {
throw new RuntimeException("Unknown worker type.");
if (config.hasPath("ray.session-dir")) {
sessionDir = config.getString("ray.session-dir");
}
if (sessionDir != null) {
sessionDir = removeTrailingSlash(sessionDir);
}
this.sessionDir = sessionDir;
// Log dir.
String localLogDir = null;
@@ -350,34 +277,6 @@ public class RayConfig {
} else {
logDir = localLogDir;
}
// Object store socket name.
String localObjectStoreSocketName = null;
if (config.hasPath("ray.object-store.socket-name")) {
localObjectStoreSocketName = config.getString("ray.object-store.socket-name");
}
if (Strings.isNullOrEmpty(localObjectStoreSocketName)) {
objectStoreSocketName = String.format("%s/sockets/object_store", sessionDir);
} else {
objectStoreSocketName = localObjectStoreSocketName;
}
// Raylet socket name.
String localRayletSocketName = null;
if (config.hasPath("ray.raylet.socket-name")) {
localRayletSocketName = config.getString("ray.raylet.socket-name");
}
if (Strings.isNullOrEmpty(localRayletSocketName)) {
rayletSocketName = String.format("%s/sockets/raylet", sessionDir);
} else {
rayletSocketName = localRayletSocketName;
}
}
@Override
public String toString() {
return render();
}
/**
@@ -43,16 +43,6 @@ public class RuntimeContextImpl implements RuntimeContext {
return runtime.getGcsClient().wasCurrentActorRestarted(getCurrentActorId());
}
@Override
public String getRayletSocketName() {
return runtime.getRayConfig().rayletSocketName;
}
@Override
public String getObjectStoreSocketName() {
return runtime.getRayConfig().objectStoreSocketName;
}
@Override
public boolean isSingleProcess() {
return RunMode.SINGLE_PROCESS == runtime.getRayConfig().runMode;
@@ -1,5 +1,6 @@
package io.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import io.ray.runtime.config.RayConfig;
/**
@@ -11,8 +12,10 @@ public class GcsClientOptions {
public String password;
public GcsClientOptions(RayConfig rayConfig) {
ip = rayConfig.getRedisIp();
port = rayConfig.getRedisPort();
String[] ipAndPort = rayConfig.getRedisAddress().split(":");
Preconditions.checkArgument(ipAndPort.length == 2, "Invalid redis address.");
ip = ipAndPort[0];
port = Integer.parseInt(ipAndPort[1]);
password = rayConfig.redisPassword;
}
}
@@ -1,31 +1,21 @@
package io.ray.runtime.runner;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.ResourceUtil;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
/**
* Ray service management on one box.
@@ -34,97 +24,78 @@ public class RunManager {
private static final Logger LOGGER = LoggerFactory.getLogger(RunManager.class);
private static final String WORKER_CLASS = "io.ray.runtime.runner.worker.DefaultWorker";
private static final String SESSION_LATEST = "session_latest";
private RayConfig rayConfig;
private List<Pair<String, Process>> processes;
private static final int KILL_PROCESS_WAIT_TIMEOUT_SECONDS = 1;
public RunManager(RayConfig rayConfig) {
this.rayConfig = rayConfig;
processes = new ArrayList<>();
createTempDirs();
}
public void cleanup() {
// Terminate the processes in the reversed order of creating them.
// Because raylet needs to exit before object store, otherwise it
// cannot exit gracefully.
for (int i = processes.size() - 1; i >= 0; --i) {
Pair<String, Process> pair = processes.get(i);
terminateProcess(pair.getLeft(), pair.getRight());
}
}
public void terminateProcess(String name, Process p) {
int numAttempts = 0;
while (p.isAlive()) {
if (numAttempts == 0) {
LOGGER.debug("Terminating process {}.", name);
p.destroy();
} else {
LOGGER.debug("Terminating process {} forcibly.", name);
p.destroyForcibly();
}
try {
p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while waiting for process {}" +
" to be terminated.", name);
}
numAttempts++;
}
LOGGER.debug("Process {} is now terminated.", name);
}
private static final Pattern pattern = Pattern.compile("--address='([^']+)'");
/**
* Get processes by name. For test purposes only.
* Start the head node.
*/
public List<Process> getProcesses(String name) {
return processes.stream().filter(pair -> pair.getLeft().equals(name)).map(Pair::getRight)
.collect(Collectors.toList());
}
private void createTempDirs() {
public static void startRayHead(RayConfig rayConfig) {
LOGGER.debug("Starting ray runtime @ {}.", rayConfig.nodeIp);
String codeSearchPath;
if (!rayConfig.codeSearchPath.isEmpty()) {
codeSearchPath = Joiner.on(File.pathSeparator).join(rayConfig.codeSearchPath);
} else {
codeSearchPath = System.getProperty("java.class.path");
}
List<String> command = new ArrayList<>();
command.add("ray");
command.add("start");
command.add("--head");
command.add("--redis-password");
command.add(rayConfig.redisPassword);
command.add("--system-config=" + new Gson().toJson(rayConfig.rayletConfigParameters));
command.add("--code-search-path=" + codeSearchPath);
command.addAll(rayConfig.headArgs);
String output;
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
FileUtils.forceMkdir(new File(rayConfig.rayletSocketName).getParentFile());
FileUtils.forceMkdir(new File(rayConfig.objectStoreSocketName).getParentFile());
// Remove session_latest first, and then create a new symbolic link for session_latest.
final String parentOfSessionDir = new File(rayConfig.sessionDir).getParent();
final File sessionLatest = new File(
String.format("%s/%s", parentOfSessionDir, SESSION_LATEST));
if (sessionLatest.exists()) {
sessionLatest.delete();
}
Files.createSymbolicLink(
Paths.get(sessionLatest.getAbsolutePath()),
Paths.get(rayConfig.sessionDir));
} catch (IOException e) {
LOGGER.error("Couldn't create temp directories.", e);
throw new RuntimeException(e);
output = runCommand(command);
} catch (Exception e) {
throw new RuntimeException("Failed to start Ray runtime.", e);
}
Matcher matcher = pattern.matcher(output);
if (matcher.find()) {
String redisAddress = matcher.group(1);
rayConfig.setRedisAddress(redisAddress);
} else {
throw new RuntimeException("Redis address is not found. output: " + output);
}
LOGGER.info("Ray runtime started @ {}.", rayConfig.nodeIp);
}
/**
* @return Log files for stdout and stderr.
* Stop ray.
*/
private Pair<File, File> getLogFiles(String logDir, String processName) {
int suffixIndex = 0;
while (true) {
String suffix = suffixIndex == 0 ? "" : "." + suffixIndex;
File stdout = new File(String.format("%s/%s%s.out", logDir, suffix, processName));
File stderr = new File(String.format("%s/%s%s.err", logDir, suffix, processName));
if (!stdout.exists() && !stderr.exists()) {
return ImmutablePair.of(stdout, stderr);
}
suffixIndex += 1;
public static void stopRay() {
List<String> command = new ArrayList<>();
command.add("ray");
command.add("stop");
command.add("--force");
try {
runCommand(command);
} catch (Exception e) {
throw new RuntimeException("Failed to stop ray.", e);
}
}
public static void getAddressInfoAndFillConfig(RayConfig rayConfig) {
// NOTE(kfstorm): This method depends on an internal Python API of ray to get the
// address info of the local node.
String script = String.format("import ray;"
+ " print(ray._private.services.get_address_info_from_redis("
+ "'%s', '%s', redis_password='%s', no_warning=True))",
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
List<String> command = Arrays.asList("python", "-c", script);
String output = null;
try {
output = runCommand(command);
JsonObject addressInfo = new JsonParser().parse(output).getAsJsonObject();
rayConfig.rayletSocketName = addressInfo.get("raylet_socket_name").getAsString();
rayConfig.objectStoreSocketName = addressInfo.get("object_store_address").getAsString();
rayConfig.nodeManagerPort = addressInfo.get("node_manager_port").getAsInt();
} catch (Exception e) {
throw new RuntimeException("Failed to get address info. Output: " + output, e);
}
}
@@ -132,284 +103,22 @@ public class RunManager {
* Start a process.
*
* @param command The command to start the process with.
* @param env Environment variables.
* @param name Process name.
*/
private void startProcess(List<String> command, Map<String, String> env, String name) {
private static String runCommand(List<String> command) throws IOException, InterruptedException {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Starting process {} with command: {}", name,
Joiner.on(" ").join(command));
LOGGER.debug("Starting process with command: {}", Joiner.on(" ").join(command));
}
ProcessBuilder builder = new ProcessBuilder(command);
String stdout = "";
String stderr = "";
// Set stdout and stderr paths.
Pair<File, File> logFiles = getLogFiles(rayConfig.logDir, name);
builder.redirectOutput(logFiles.getLeft());
builder.redirectError(logFiles.getRight());
// Set environment variables.
if (env != null && !env.isEmpty()) {
builder.environment().putAll(env);
}
Process p;
try {
p = builder.start();
} catch (IOException e) {
LOGGER.error("Failed to start process " + name, e);
throw new RuntimeException("Failed to start process " + name, e);
}
// Wait 1000 ms and check whether the process is alive.
try {
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
if (!p.isAlive()) {
String message = String.format("Failed to start %s. Exit code: %d.",
name, p.exitValue());
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
throw new RuntimeException(message);
}
processes.add(Pair.of(name, p));
if (LOGGER.isDebugEnabled()) {
String message = String.format("%s process started.", name);
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
LOGGER.debug(message);
ProcessBuilder builder = new ProcessBuilder(command).redirectErrorStream(true);
Process p = builder.start();
String output = IOUtils.toString(p.getInputStream(), Charset.defaultCharset());
p.waitFor();
if (p.exitValue() != 0) {
String sb = "The exit value of the process is " + p.exitValue()
+ ". Command: " + Joiner.on(" ").join(command) + "\n"
+ "output:\n" + output;
throw new RuntimeException(sb);
}
return output;
}
/**
* Start all Ray processes on this node.
*
* @param isHead Whether this node is the head node. If true, redis server will be started.
*/
public void startRayProcesses(boolean isHead) {
LOGGER.debug("Starting ray runtime @ {}.", rayConfig.nodeIp);
try {
if (isHead) {
startGcs();
}
startObjectStore();
startRaylet(isHead);
LOGGER.info("Ray runtime started @ {}.", rayConfig.nodeIp);
} catch (Exception e) {
// Clean up started processes.
cleanup();
LOGGER.error("Failed to start ray runtime.", e);
throw new RuntimeException("Failed to start ray runtime.", e);
}
}
private void startGcs() {
// start primary redis
String primary = startRedisInstance(rayConfig.nodeIp,
rayConfig.headRedisPort, rayConfig.headRedisPassword, null);
rayConfig.setRedisAddress(primary);
try (Jedis client = new Jedis("127.0.0.1", rayConfig.headRedisPort)) {
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
client.auth(rayConfig.headRedisPassword);
}
client.set("UseRaylet", "1");
// Set job counter to compute job id.
client.set("JobCounter", "0");
// Register the number of Redis shards in the primary shard, so that clients
// know how many redis shards to expect under RedisShards.
client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards));
// Set session dir for this cluster, so that the drivers which connected to this
// cluster will fetch this session dir as its self's session dir.
client.set("session_dir", rayConfig.getSessionDir());
// start redis shards
for (int i = 0; i < rayConfig.numberRedisShards; i++) {
String shard = startRedisInstance(rayConfig.nodeIp,
rayConfig.redisShardPorts[i], rayConfig.headRedisPassword, i);
client.rpush("RedisShards", shard);
}
}
// start gcs server
String redisPasswordOption = "";
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
redisPasswordOption = rayConfig.headRedisPassword;
}
// See `src/ray/gcs/gcs_server/gcs_server_main.cc` for the meaning of each parameter.
final File gcsServerFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.GCS_SERVER_BINARY_NAME);
Preconditions.checkState(gcsServerFile.setExecutable(true));
List<String> command = ImmutableList.of(
gcsServerFile.getAbsolutePath(),
String.format("--redis_address=%s", rayConfig.getRedisIp()),
String.format("--redis_port=%d", rayConfig.getRedisPort()),
String.format("--config_list=%s",
rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue()).collect(Collectors
.joining(","))),
String.format("--redis_password=%s", redisPasswordOption)
);
startProcess(command, null, "gcs_server");
}
private String startRedisInstance(String ip, int port, String password, Integer shard) {
final File redisServerFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.REDIS_SERVER_BINARY_NAME);
Preconditions.checkState(redisServerFile.setExecutable(true));
// The redis module file.
File redisModule = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.REDIS_MODULE_LIBRARY_NAME);
Preconditions.checkState(redisModule.setExecutable(true));
List<String> command = Lists.newArrayList(
// The redis-server executable file.
redisServerFile.getAbsolutePath(),
"--protected-mode",
"no",
"--port",
String.valueOf(port),
"--loglevel",
"warning",
"--loadmodule",
// The redis module file.
redisModule.getAbsolutePath()
);
if (!Strings.isNullOrEmpty(password)) {
command.add("--requirepass ");
command.add(password);
}
String name = shard == null ? "redis" : "redis-shard_" + shard;
startProcess(command, null, name);
try (Jedis client = new Jedis("127.0.0.1", port)) {
if (!Strings.isNullOrEmpty(password)) {
client.auth(password);
}
// Configure Redis to only generate notifications for the export keys.
client.configSet("notify-keyspace-events", "Kl");
// Put a time stamp in Redis to indicate when it was started.
client.set("redis_start_time", LocalDateTime.now().toString());
}
return ip + ":" + port;
}
private void startRaylet(boolean isHead) throws IOException {
int hardwareConcurrency = Runtime.getRuntime().availableProcessors();
int maximumStartupConcurrency = Math.max(1,
Math.min(rayConfig.resources.getOrDefault("CPU", 0.0).intValue(), hardwareConcurrency));
String redisPasswordOption = "";
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
redisPasswordOption = rayConfig.headRedisPassword;
}
// See `src/ray/raylet/main.cc` for the meaning of each parameter.
final File rayletFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.RAYLET_BINARY_NAME);
Preconditions.checkState(rayletFile.setExecutable(true));
List<String> command = ImmutableList.of(
rayletFile.getAbsolutePath(),
String.format("--raylet_socket_name=%s", rayConfig.rayletSocketName),
String.format("--store_socket_name=%s", rayConfig.objectStoreSocketName),
String.format("--object_manager_port=%d", 0), // The object manager port.
// The node manager port.
String.format("--node_manager_port=%d", rayConfig.getNodeManagerPort()),
String.format("--node_ip_address=%s", rayConfig.nodeIp),
String.format("--redis_address=%s", rayConfig.getRedisIp()),
String.format("--redis_port=%d", rayConfig.getRedisPort()),
String.format("--num_initial_workers=%d", 0), // number of initial workers
String.format("--maximum_startup_concurrency=%d", maximumStartupConcurrency),
String.format("--static_resource_list=%s",
ResourceUtil.getResourcesStringFromMap(rayConfig.resources)),
String.format("--config_list=%s", rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue())
.collect(Collectors.joining(","))),
String.format("--python_worker_command=%s", buildPythonWorkerCommand()),
String.format("--java_worker_command=%s", buildWorkerCommand()),
String.format("--redis_password=%s", redisPasswordOption),
isHead ? "--head_node" : ""
);
startProcess(command, null, "raylet");
}
private String concatPath(Stream<String> stream) {
// TODO (hchen): Right now, raylet backend doesn't support worker command with spaces.
// Thus, we have to drop some some paths until that is fixed.
return stream.filter(s -> !s.contains(" ")).collect(Collectors.joining(":"));
}
private String buildWorkerCommand() throws IOException {
List<String> cmd = new ArrayList<>();
cmd.add("java");
cmd.add("-classpath");
// Generate classpath based on current classpath + user-defined classpath.
String classpath = concatPath(Stream.concat(
rayConfig.classpath.stream(),
Stream.of(System.getProperty("java.class.path").split(":"))
));
cmd.add(classpath);
// Write current config to a file, and set the file path as Java worker's config file.
// This allows users to set worker config by setting driver's system properties.
File workerConfigFile = new File(rayConfig.sessionDir + "/java_worker.conf");
FileUtils.write(workerConfigFile, rayConfig.render(), Charset.defaultCharset());
cmd.add("-Dray.config-file=" + workerConfigFile.getAbsolutePath());
if (!rayConfig.codeSearchPath.isEmpty()) {
cmd.add("-Dray.job.code-search-path=" +
String.join(":", rayConfig.codeSearchPath));
}
cmd.add("RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER");
cmd.addAll(rayConfig.jvmParameters);
// jvm options
cmd.add("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER");
// Main class
cmd.add(WORKER_CLASS);
String command = Joiner.on(" ").join(cmd);
LOGGER.debug("Worker command is: {}", command);
return command;
}
private void startObjectStore() {
final File objectStoreFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.PLASMA_STORE_SERVER_BINARY_NAME);
Preconditions.checkState(objectStoreFile.setExecutable(true));
List<String> command = ImmutableList.of(
// The plasma store executable file.
objectStoreFile.getAbsolutePath(),
"-s",
rayConfig.objectStoreSocketName,
"-m",
rayConfig.objectStoreSize.toString()
);
startProcess(command, null, "plasma_store");
}
private String buildPythonWorkerCommand() {
// disable python worker start from raylet, which starts from java
if (rayConfig.pythonWorkerCommand == null) {
return "";
}
List<String> cmd = new ArrayList<>();
cmd.add(rayConfig.pythonWorkerCommand);
cmd.add("--node-ip-address=" + rayConfig.nodeIp);
cmd.add("--object-store-name=" + rayConfig.objectStoreSocketName);
cmd.add("--raylet-name=" + rayConfig.rayletSocketName);
cmd.add("--address=" + rayConfig.getRedisAddress());
String command = cmd.stream().collect(Collectors.joining(" "));
LOGGER.debug("python worker command: {}", command);
return command;
}
}
@@ -12,15 +12,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.SystemUtils;
public class BinaryFileUtil {
public static final String REDIS_SERVER_BINARY_NAME = "redis-server";
public static final String GCS_SERVER_BINARY_NAME = "gcs_server";
public static final String PLASMA_STORE_SERVER_BINARY_NAME = "plasma_store_server";
public static final String RAYLET_BINARY_NAME = "raylet";
public static final String REDIS_MODULE_LIBRARY_NAME = "libray_redis_module.so";
public static final String CORE_WORKER_JAVA_LIBRARY = "core_worker_library_java";
@@ -2,8 +2,9 @@ package io.ray.runtime.util;
import com.google.common.collect.Sets;
import com.sun.jna.NativeLibrary;
import io.ray.runtime.config.RayConfig;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -11,17 +12,7 @@ import org.slf4j.LoggerFactory;
public class JniUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(JniUtils.class);
private static Set<String> loadedLibs = Sets.newHashSet();
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String libraryName) {
loadLibrary(libraryName, false);
}
private static String defaultDestDir;
/**
* Loads the native library specified by the <code>libraryName</code> argument.
@@ -29,15 +20,51 @@ public class JniUtils {
* prefix, file extension or path.
*
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String libraryName) {
loadLibrary(getDefaultDestDir(), libraryName);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
*/
public static synchronized void loadLibrary(String libraryName, boolean exportSymbols) {
loadLibrary(getDefaultDestDir(), libraryName, exportSymbols);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String destDir, String libraryName) {
loadLibrary(destDir, libraryName, false);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
*/
public static synchronized void loadLibrary(String destDir, String libraryName,
boolean exportSymbols) {
if (!loadedLibs.contains(libraryName)) {
LOGGER.debug("Loading native library {}.", libraryName);
// Load native library.
String fileName = System.mapLibraryName(libraryName);
final String sessionDir = RayConfig.getInstance().sessionDir;
final File file = BinaryFileUtil.getNativeFile(sessionDir, fileName);
final File file = BinaryFileUtil.getNativeFile(destDir, fileName);
if (exportSymbols) {
// Expose library symbols using RTLD_GLOBAL which may be depended by other shared
@@ -50,4 +77,17 @@ public class JniUtils {
}
}
/**
* Cache the result so that multiple calls return the same dest dir.
*/
private static synchronized String getDefaultDestDir() {
if (defaultDestDir == null) {
try {
defaultDestDir = Files.createTempDirectory("native_libs").toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return defaultDestDir;
}
}
@@ -18,9 +18,6 @@ ray {
// `CLUSTER`: Ray is running on one or more nodes, with multiple processes.
run-mode: CLUSTER
// Available resources on this node, for example "CPU:4,GPU:0".
resources: ""
// Configuration items about job.
job {
// If worker.mode is DRIVER, specify the job id.
@@ -56,40 +53,12 @@ ray {
max-backup-files: 10
}
// Custom worker jvm parameters.
worker.jvm-parameters: []
// Custom `java.library.path`
// Note, do not use `dir1:dir2` format, put each dir as a list item.
library.path: []
// Custom classpath.
// Note, do not use `dir1:dir2` format, put each dir as a list item.
classpath = []
// ----------------------
// Redis configurations
// ----------------------
redis {
// If `redis.server` isn't provided, which port we should use to start redis server.
// If `head-port` is not provided, it will be generated randomly.
// head-port: 6379
// Below passwords should be consistent with the one defined in python/ray/ray_constants.py.
// The password used to start the redis server on the head node.
head-password: "5241590000000000"
// The password used to connect to the redis server.
password: "5241590000000000"
// If `redis.server` isn't provided, how many Redis shards we should start in addition to the
// primary Redis shard. The ports of these shards will be `head-port + 1`, `head-port + 2`, etc.
shard-number: 1
}
// ----------------------------
// Object store configurations
// ----------------------------
object-store {
// Initial size of the object store.
size: 10 MB
}
// ----------------------------
@@ -97,12 +66,14 @@ ray {
// ----------------------------
raylet {
// See src/ray/ray_config_def.h for options.
// Below section takes effect only if Ray head is started by a driver.
config {
// TODO(zhuohan): enable this for java
put_small_object_in_memory_store: false
}
}
// Whether we enable job manager to submit and manage job.
enable-job-manager: false
// Below args will be appended as parameters of the `ray start` command.
// It takes effect only if Ray head is started by a driver.
head-args: []
}
@@ -2,6 +2,8 @@ package io.ray.runtime.config;
import io.ray.runtime.generated.Common.WorkerType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -11,37 +13,39 @@ public class RayConfigTest {
@Test
public void testCreateRayConfig() {
Map<String, String> rayletConfig = new HashMap<>();
rayletConfig.put("one", "1");
rayletConfig.put("zero", "0");
rayletConfig.put("positive-integer", "123");
rayletConfig.put("negative-integer", "-123");
rayletConfig.put("float", "-123.456");
rayletConfig.put("true", "true");
rayletConfig.put("false", "false");
rayletConfig.put("string", "abc");
try {
System.setProperty("ray.job.code-search-path", "path/to/ray/job/resource/path");
for (Map.Entry<String, String> entry : rayletConfig.entrySet()) {
System.setProperty("ray.raylet.config." + entry.getKey(), entry.getValue());
}
RayConfig rayConfig = RayConfig.create();
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals(Collections.singletonList("path/to/ray/job/resource/path"),
rayConfig.codeSearchPath);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("one"), 1);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("zero"), 0);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("positive-integer"), 123);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("negative-integer"), -123);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("float"), -123.456f);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("true"), true);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("false"), false);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("string"), "abc");
} finally {
// Unset system properties.
System.clearProperty("ray.job.code-search-path");
}
}
@Test
public void testGenerateHeadPortRandomly() {
boolean isSame = true;
final int port1 = RayConfig.create().headRedisPort;
// If we the 2 ports are the same, let's retry.
// This is used to avoid any flaky chance.
for (int i = 0; i < NUM_RETRIES; ++i) {
final int port2 = RayConfig.create().headRedisPort;
if (port1 != port2) {
isSame = false;
break;
for (String key : rayletConfig.keySet()) {
System.clearProperty("ray.raylet.config." + key);
}
}
Assert.assertFalse(isSame);
}
@Test
public void testSpecifyHeadPort() {
System.setProperty("ray.redis.head-port", "11111");
Assert.assertEquals(RayConfig.create().headRedisPort, 11111);
}
}
+2 -2
View File
@@ -44,8 +44,8 @@ TEST_ARGS=(-Dray.raylet.config.num_workers_per_process_java=10 -Dray.job.num-jav
echo "Running tests under cluster mode."
# TODO(hchen): Ideally, we should use the following bazel command to run Java tests. However, if there're skipped tests,
# TestNG will exit with code 2. And bazel treats it as test failure.
# bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --config=ci || cluster_exit_code=$?
ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
# bazel test //java:all_tests --config=ci || cluster_exit_code=$?
run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
echo "Running tests under single-process mode."
# bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --config=ci || single_exit_code=$?
@@ -1,127 +0,0 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.lang.ProcessBuilder.Redirect;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
@Test(groups = {"cluster", "multiLanguage"})
public abstract class BaseMultiLanguageTest {
private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
Map<String, String> env) {
try {
LOGGER.info("Executing command: {}", String.join(" ", command));
ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT)
.redirectError(Redirect.INHERIT);
for (Entry<String, String> entry : env.entrySet()) {
processBuilder.environment().put(entry.getKey(), entry.getValue());
}
Process process = processBuilder.start();
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
@BeforeClass(alwaysRun = true, inheritGroups = false)
public void setUp() {
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
file.delete();
}
}
String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
// jars in the `ray` wheel doesn't contains test classes, so we add test classes explicitly.
// Since mvn test classes contains `test` in path and bazel test classes is located at a jar
// with `test` included in the name, we can check classpath `test` to filter out test classes.
List<String> classpath = Stream.of(System.getProperty("java.class.path").split(":"))
.filter(s -> !s.contains(" ") && s.contains("test"))
.collect(Collectors.toList());
// Start ray cluster.
List<String> startCommand = Arrays.asList(
"ray",
"start",
"--head",
"--port=6379",
"--min-worker-port=0",
"--max-worker-port=0",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
String.format("--node-manager-port=%s", nodeManagerPort),
"--load-code-from-local",
"--system-config=" + new Gson().toJson(RayConfig.create().rayletConfigParameters),
"--code-search-path=" + String.join(":", classpath)
);
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
Assert.assertFalse(Ray.isInitialized());
System.setProperty("ray.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
Ray.init();
}
/**
* @return The environment variables needed for the `ray start` command.
*/
protected Map<String, String> getRayStartEnv() {
return ImmutableMap.of();
}
@AfterClass(alwaysRun = true, inheritGroups = false)
public void tearDown() {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.raylet.node-manager-port");
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
}
@@ -1,46 +1,22 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import io.ray.api.Ray;
import java.io.File;
import java.lang.reflect.Method;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
public class BaseTest {
private static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class);
private List<File> filesToDelete = ImmutableList.of();
@BeforeMethod(alwaysRun = true)
public void setUpBase(Method method) {
Assert.assertFalse(Ray.isInitialized());
Ray.init();
// These files need to be deleted after each test case.
filesToDelete = ImmutableList.of(
new File(Ray.getRuntimeContext().getRayletSocketName()),
new File(Ray.getRuntimeContext().getObjectStoreSocketName()),
// TODO(pcm): This is a workaround for the issue described
// in the PR description of https://github.com/ray-project/ray/pull/5450
// and should be fixed properly.
new File("/tmp/ray/test/raylet_socket")
);
// Make sure the files will be deleted even if the test doesn't exit gracefully.
filesToDelete.forEach(File::deleteOnExit);
}
@AfterMethod(alwaysRun = true)
public void tearDownBase() {
Ray.shutdown();
for (File file : filesToDelete) {
file.delete();
}
}
}
@@ -1,7 +1,6 @@
package io.ray.test;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
@@ -19,17 +18,19 @@ import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class CrossLanguageInvocationTest extends BaseTest {
private static final String PYTHON_MODULE = "test_cross_language_invocation";
@Override
protected Map<String, String> getRayStartEnv() {
@BeforeClass
public void beforeClass() {
// Delete and re-create the temp dir.
File tempDir = new File(
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
@@ -48,7 +49,14 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
throw new RuntimeException(e);
}
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
System.setProperty("ray.job.code-search-path",
System.getProperty("java.class.path") + File.pathSeparator
+ tempDir.getAbsolutePath());
}
@AfterClass
public void afterClass() {
System.clearProperty("ray.job.code-search-path");
}
@Test
@@ -16,12 +16,12 @@ public class GcsClientTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "A:8");
System.setProperty("ray.head-args.0", "--resources={\"A\":8}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
}
public void testGetAllNodeInfo() {
@@ -17,12 +17,12 @@ public class GlobalGcTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "140 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 140L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
public static class LargeObjectWithCyclicRef {
@@ -5,7 +5,8 @@ import io.ray.api.Ray;
import org.testng.Assert;
import org.testng.annotations.Test;
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class MultiLanguageClusterTest extends BaseTest {
public static String echo(String word) {
return word;
@@ -18,9 +18,6 @@ public class RayAlterSuiteListener implements IAlterSuiteListener {
XmlGroups groups = new XmlGroups();
XmlRun run = new XmlRun();
run.onExclude(excludedGroup);
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
run.onExclude("multiLanguage");
}
groups.setRun(run);
suite.setGroups(groups);
}
@@ -8,7 +8,8 @@ import io.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
import org.testng.annotations.Test;
public class RaySerializerTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class RaySerializerTest extends BaseTest {
@Test
public void testSerializePyActor() {
@@ -25,7 +25,9 @@ public class RayletConfigTest extends BaseTest {
public static class TestActor {
public String getConfigValue() {
return TestUtils.getRuntime().getRayConfig().rayletConfigParameters.get(RAY_CONFIG_KEY);
return TestUtils.getRuntime().getRayConfig()
.rayletConfigParameters.get(RAY_CONFIG_KEY)
.toString();
}
}
@@ -11,13 +11,11 @@ public class RedisPasswordTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.redis.head-password", "12345678");
System.setProperty("ray.redis.password", "12345678");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.redis.head-password");
System.clearProperty("ray.redis.password");
}
@@ -27,12 +27,12 @@ import org.testng.annotations.Test;
public class ReferenceCountingTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "100 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 100L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
/**
@@ -20,12 +20,14 @@ public class ResourcesManagementTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
System.clearProperty("ray.head-args.1");
}
public static Integer echo(Integer number) {
@@ -14,8 +14,6 @@ import org.testng.annotations.Test;
public class RuntimeContextTest extends BaseTest {
private static JobId JOB_ID = getJobId();
private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket";
private static JobId getJobId() {
// Must be stable across different processes.
@@ -27,23 +25,16 @@ public class RuntimeContextTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.job.id", JOB_ID.toString());
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.job.id");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.object-store.socket-name");
}
@Test
public void testRuntimeContextInDriver() {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
}
public static class RuntimeContextTester {
@@ -51,9 +42,6 @@ public class RuntimeContextTest extends BaseTest {
public String testRuntimeContext(ActorId actorId) {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
return "ok";
}
}