[Java] Simplify Ray.init() by invoking ray start internally (#10762)

This commit is contained in:
Kai Yang
2020-12-04 14:33:45 +08:00
committed by GitHub
parent 8cebe1e79c
commit 21fcee28f9
39 changed files with 367 additions and 1085 deletions
+1 -1
View File
@@ -135,7 +135,7 @@ The latest Ray Java snapshot can be found in `sonatype repository <https://oss.s
.. note::
When you run ``pip install`` to install Ray, Java jars are installed as well. The above dependencies are only used to build your Java code and to run your code in local or single machine mode.
When you run ``pip install`` to install Ray, Java jars are installed as well. The above dependencies are only used to build your Java code and to run your code in local mode.
If you want to run your Java code in a multi-node Ray cluster, it's better to exclude Ray jars when packaging your code to avoid jar conficts if the versions (installed Ray with ``pip install`` and maven dependencies) don't match.
+1 -22
View File
@@ -70,6 +70,7 @@ define_java_module(
visibility = ["//visibility:public"],
deps = [
":io_ray_ray_api",
"@maven//:com_google_code_gson_gson",
"@maven//:com_google_guava_guava",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:com_typesafe_config",
@@ -134,27 +135,12 @@ filegroup(
],
)
native_java_binary("runtime", "raylet", "//:raylet")
native_java_binary("runtime", "plasma_store_server", "//:plasma_store_server")
native_java_binary("runtime", "redis-server", "//:redis-server")
native_java_binary("runtime", "gcs_server", "//:gcs_server")
native_java_binary("runtime", "libray_redis_module.so", "//:libray_redis_module.so")
native_java_library("runtime", "core_worker_library_java", "//:libcore_worker_library_java.so")
filegroup(
name = "java_native_deps",
srcs = [
":core_worker_library_java",
":gcs_server",
":libray_redis_module.so",
":plasma_store_server",
":raylet",
":redis-server",
],
)
@@ -252,13 +238,6 @@ genrule(
WORK_DIR="$$(pwd)"
rm -rf "$$WORK_DIR/python/ray/jars" && mkdir -p "$$WORK_DIR/python/ray/jars"
cp -f $(location //java:ray_dist_deploy.jar) "$$WORK_DIR/python/ray/jars/ray_dist.jar"
chmod +w "$$WORK_DIR/python/ray/jars/ray_dist.jar"
zip -d "$$WORK_DIR/python/ray/jars/ray_dist.jar" \
"native/*/gcs_server" \
"native/*/libray_redis_module.so" \
"native/*/plasma_store_server" \
"native/*/raylet" \
"native/*/redis-server"
date > $@
""",
local = 1,
@@ -28,16 +28,6 @@ public interface RuntimeContext {
*/
boolean wasCurrentActorRestarted();
/**
* Get the raylet socket name.
*/
String getRayletSocketName();
/**
* Get the object store socket name.
*/
String getObjectStoreSocketName();
/**
* Return true if Ray is running in single-process mode, false if Ray is running in cluster mode.
*/
+5
View File
@@ -39,6 +39,11 @@
<artifactId>ray-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.5</version>
</dependency>
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
@@ -16,7 +16,7 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
@Override
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.getInstance();
RayConfig rayConfig = RayConfig.create();
LoggingUtil.setupLogging(rayConfig);
Logger logger = LoggerFactory.getLogger(DefaultRayRuntimeFactory.class);
@@ -57,7 +57,6 @@ public class RayDevRuntime extends AbstractRayRuntime {
taskSubmitter = null;
}
taskExecutor = null;
RayConfig.reset();
}
@Override
@@ -5,10 +5,8 @@ import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.JobId;
import io.ray.api.id.UniqueId;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.NativeWorkerContext;
import io.ray.runtime.exception.RayException;
import io.ray.runtime.exception.RayIntentionalSystemExitException;
import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.gcs.GcsClientOptions;
@@ -22,15 +20,12 @@ import io.ray.runtime.task.NativeTaskSubmitter;
import io.ray.runtime.task.TaskExecutor;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.JniUtils;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.commons.io.FileUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -41,7 +36,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
private RunManager manager = null;
private boolean startRayHead = false;
/**
* In Java, GC runs in a standalone thread, and we can't control the exact
@@ -52,124 +47,101 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
*/
private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock();
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
static {
LOGGER.debug("Loading native libraries.");
// Expose ray ABI symbols which may be depended by other shared
// libraries such as libstreaming_java.so.
// See BUILD.bazel:libcore_worker_library_java.so
final RayConfig rayConfig = RayConfig.getInstance();
if (rayConfig.getRedisAddress() != null && rayConfig.workerMode == WorkerType.DRIVER) {
// Fetch session dir from GCS if this is a driver that is connecting to the existing GCS.
private void updateSessionDir() {
if (rayConfig.workerMode == WorkerType.DRIVER) {
// Fetch session dir from GCS if this is a driver.
RedisClient client = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
final String sessionDir = client.get("session_dir", null);
Preconditions.checkNotNull(sessionDir);
rayConfig.setSessionDir(sessionDir);
}
JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
LOGGER.debug("Native libraries loaded.");
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
}
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
loadConfigFromGcs(rayConfig);
}
private static void loadConfigFromGcs(RayConfig rayConfig) {
if (rayConfig.getRedisAddress() != null) {
GcsClient tempGcsClient =
new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
for (Map.Entry<String, String> entry :
tempGcsClient.getInternalConfig().entrySet()) {
rayConfig.rayletConfigParameters.put(entry.getKey(), entry.getValue());
}
if (rayConfig.workerMode == WorkerType.DRIVER) {
// Keep this method logic in sync with `services.get_address_info_from_redis_helper`
int numRetries = 5;
int retryCount = 0;
boolean configLoaded = false;
while (retryCount++ < numRetries) {
for (NodeInfo nodeInfo : tempGcsClient.getAllNodeInfo()) {
if (rayConfig.nodeIp.equals(nodeInfo.nodeAddress) ||
(nodeInfo.nodeAddress.equals("127.0.0.1") &&
rayConfig.nodeIp.equals(rayConfig.getRedisAddress()))) {
rayConfig.objectStoreSocketName = nodeInfo.objectStoreSocketName;
rayConfig.rayletSocketName = nodeInfo.rayletSocketName;
rayConfig.nodeManagerPort = nodeInfo.nodeManagerPort;
configLoaded = true;
break;
}
}
if (!configLoaded) {
LOGGER.warn("Some processes that the driver needs to connect to have " +
"not registered with Redis, so retrying. Have you run " +
"'ray start' on this node?");
try {
TimeUnit.SECONDS.sleep(1);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
} else {
break;
}
}
if (!configLoaded) {
throw new RayException("Some processes that the driver needs to connect to have " +
"not registered with Redis. Have you run 'ray start' on this node?");
}
}
private void loadConfigFromGcs() {
rayConfig.rayletConfigParameters.clear();
for (Map.Entry<String, String> entry : gcsClient.getInternalConfig().entrySet()) {
rayConfig.rayletConfigParameters.put(entry.getKey(), entry.getValue());
}
}
@Override
public void start() {
if (rayConfig.getRedisAddress() == null) {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
try {
if (rayConfig.workerMode == WorkerType.DRIVER && rayConfig.getRedisAddress() == null) {
// Set it to true before `RunManager.startRayHead` so `Ray.shutdown()` can still kill
// Ray processes even if `Ray.init()` failed.
startRayHead = true;
RunManager.startRayHead(rayConfig);
}
Preconditions.checkNotNull(rayConfig.getRedisAddress());
updateSessionDir();
// Expose ray ABI symbols which may be depended by other shared
// libraries such as libstreaming_java.so.
// See BUILD.bazel:libcore_worker_library_java.so
Preconditions.checkNotNull(rayConfig.sessionDir);
JniUtils.loadLibrary(rayConfig.sessionDir, BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
if (rayConfig.workerMode == WorkerType.DRIVER) {
RunManager.getAddressInfoAndFillConfig(rayConfig);
}
gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
loadConfigFromGcs();
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.putAllWorkerEnv(rayConfig.workerEnv)
.addAllCodeSearchPath(rayConfig.codeSearchPath);
serializedJobConfig = jobConfigBuilder.build().toByteArray();
}
Map<String, String> rayletConfigStringMap = new HashMap<>();
for (Map.Entry<String, Object> entry : rayConfig.rayletConfigParameters.entrySet()) {
rayletConfigStringMap.put(entry.getKey(), entry.getValue().toString());
}
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayletConfigStringMap, serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext, shutdownLock);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
} catch (Exception e) {
if (startRayHead) {
try {
RunManager.stopRay();
} catch (Exception e2) {
// Ignore
}
}
throw e;
}
gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
byte[] serializedJobConfig = null;
if (rayConfig.workerMode == WorkerType.DRIVER) {
JobConfig.Builder jobConfigBuilder =
JobConfig.newBuilder()
.setNumJavaWorkersPerProcess(rayConfig.numWorkersPerProcess)
.addAllJvmOptions(rayConfig.jvmOptionsForJavaWorker)
.putAllWorkerEnv(rayConfig.workerEnv)
.addAllCodeSearchPath(rayConfig.codeSearchPath);
serializedJobConfig = jobConfigBuilder.build().toByteArray();
}
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayConfig.rayletConfigParameters, serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext, shutdownLock);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
}
@Override
@@ -183,27 +155,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
try {
if (rayConfig.workerMode == WorkerType.DRIVER) {
nativeShutdown();
if (null != manager) {
manager.cleanup();
manager = null;
if (startRayHead) {
startRayHead = false;
RunManager.stopRay();
}
}
if (null != gcsClient) {
gcsClient.destroy();
gcsClient = null;
}
RayConfig.reset();
LOGGER.debug("RayNativeRuntime shutdown");
} finally {
writeLock.unlock();
}
}
// For test purpose only
public RunManager getRunManager() {
return manager;
}
@Override
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
Preconditions.checkArgument(Double.compare(capacity, 0) >= 0);
@@ -2,7 +2,6 @@ package io.ray.runtime.config;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigException;
@@ -12,17 +11,15 @@ import com.typesafe.config.ConfigValue;
import io.ray.api.id.JobId;
import io.ray.runtime.generated.Common.WorkerType;
import io.ray.runtime.util.NetworkUtil;
import io.ray.runtime.util.ResourceUtil;
import java.io.File;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
/**
* Configurations of Ray runtime.
@@ -33,13 +30,6 @@ public class RayConfig {
public static final String DEFAULT_CONFIG_FILE = "ray.default.conf";
public static final String CUSTOM_CONFIG_FILE = "ray.conf";
private static final Random RANDOM = new Random();
private static final DateTimeFormatter DATE_TIME_FORMATTER =
DateTimeFormatter.ofPattern("yyyy-MM-dd_HH-mm-ss");
private static final String DEFAULT_TEMP_DIR = "/tmp/ray";
private Config config;
/**
@@ -48,54 +38,25 @@ public class RayConfig {
public final String nodeIp;
public final WorkerType workerMode;
public final RunMode runMode;
public final Map<String, Double> resources;
private JobId jobId;
public String sessionDir;
public String logDir;
public final List<String> libraryPath;
public final List<String> classpath;
public final List<String> jvmParameters;
private String redisAddress;
private String redisIp;
private Integer redisPort;
public final int headRedisPort;
public final int[] redisShardPorts;
public final int numberRedisShards;
public final String headRedisPassword;
public final String redisPassword;
// RPC socket name of object store.
public String objectStoreSocketName;
public final Long objectStoreSize;
// RPC socket name of Raylet.
public String rayletSocketName;
// Listening port for node manager.
public int nodeManagerPort;
public final Map<String, String> rayletConfigParameters;
public final Map<String, Object> rayletConfigParameters;
public List<String> codeSearchPath;
public final String pythonWorkerCommand;
public final List<String> codeSearchPath;
private static volatile RayConfig instance = null;
public static RayConfig getInstance() {
if (instance == null) {
synchronized (RayConfig.class) {
if (instance == null) {
instance = RayConfig.create();
}
}
}
return instance;
}
public static void reset() {
synchronized (RayConfig.class) {
instance = null;
}
}
public final List<String> headArgs;
public final int numWorkersPerProcess;
@@ -140,15 +101,6 @@ public class RayConfig {
} else {
nodeIp = NetworkUtil.getIpAddress(null);
}
// Resources.
resources = ResourceUtil.getResourcesMapFromString(
config.getString("ray.resources"));
if (isDriver) {
if (!resources.containsKey("CPU")) {
int numCpu = Runtime.getRuntime().availableProcessors();
resources.put("CPU", numCpu * 1.0);
}
}
// Job id.
String jobId = config.getString("ray.job.id");
if (!jobId.isEmpty()) {
@@ -168,25 +120,16 @@ public class RayConfig {
}
}
workerEnv = workerEnvBuilder.build();
updateSessionDir();
// Object store configurations.
objectStoreSize = config.getBytes("ray.object-store.size");
updateSessionDir(null);
// Library path.
libraryPath = config.getStringList("ray.library.path");
// Custom classpath.
classpath = config.getStringList("ray.classpath");
// Custom worker jvm parameters.
if (config.hasPath("ray.worker.jvm-parameters")) {
jvmParameters = config.getStringList("ray.worker.jvm-parameters");
} else {
jvmParameters = ImmutableList.of();
// Object store socket name.
if (config.hasPath("ray.object-store.socket-name")) {
objectStoreSocketName = config.getString("ray.object-store.socket-name");
}
if (config.hasPath("ray.worker.python-command")) {
pythonWorkerCommand = config.getString("ray.worker.python-command");
} else {
pythonWorkerCommand = null;
// Raylet socket name.
if (config.hasPath("ray.raylet.socket-name")) {
rayletSocketName = config.getString("ray.raylet.socket-name");
}
// Redis configurations.
@@ -198,17 +141,6 @@ public class RayConfig {
this.redisAddress = null;
}
if (config.hasPath("ray.redis.head-port")) {
headRedisPort = config.getInt("ray.redis.head-port");
} else {
headRedisPort = NetworkUtil.getUnusedPort();
}
numberRedisShards = config.getInt("ray.redis.shard-number");
redisShardPorts = new int[numberRedisShards];
for (int i = 0; i < numberRedisShards; i++) {
redisShardPorts[i] = NetworkUtil.getUnusedPort();
}
headRedisPassword = config.getString("ray.redis.head-password");
redisPassword = config.getString("ray.redis.password");
// Raylet node manager port.
if (config.hasPath("ray.raylet.node-manager-port")) {
@@ -216,7 +148,6 @@ public class RayConfig {
} else {
Preconditions.checkState(workerMode != WorkerType.WORKER,
"Worker started by raylet should accept the node manager port from raylet.");
nodeManagerPort = NetworkUtil.getUnusedPort();
}
// Raylet parameters.
@@ -224,13 +155,27 @@ public class RayConfig {
Config rayletConfig = config.getConfig("ray.raylet.config");
for (Map.Entry<String, ConfigValue> entry : rayletConfig.entrySet()) {
Object value = entry.getValue().unwrapped();
rayletConfigParameters.put(entry.getKey(), value == null ? "" : value.toString());
if (value != null) {
if (value instanceof String) {
String valueString = (String) value;
Boolean booleanValue = BooleanUtils.toBooleanObject(valueString);
if (booleanValue != null) {
value = booleanValue;
} else if (NumberUtils.isParsable(valueString)) {
value = NumberUtils.createNumber(valueString);
}
}
rayletConfigParameters.put(entry.getKey(), value);
}
}
// Job code search path.
String codeSearchPathString = null;
if (config.hasPath("ray.job.code-search-path")) {
codeSearchPath = Arrays.asList(
config.getString("ray.job.code-search-path").split(":"));
codeSearchPathString = config.getString("ray.job.code-search-path");
}
if (!StringUtils.isEmpty(codeSearchPathString)) {
codeSearchPath = Arrays.asList(codeSearchPathString.split(":"));
} else {
codeSearchPath = Collections.emptyList();
}
@@ -258,6 +203,8 @@ public class RayConfig {
numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
}
headArgs = config.getStringList("ray.head-args");
// Validate config.
validate();
}
@@ -267,24 +214,12 @@ public class RayConfig {
Preconditions.checkState(this.redisAddress == null, "Redis address was already set");
this.redisAddress = redisAddress;
String[] ipAndPort = redisAddress.split(":");
Preconditions.checkArgument(ipAndPort.length == 2, "Invalid redis address.");
this.redisIp = ipAndPort[0];
this.redisPort = Integer.parseInt(ipAndPort[1]);
}
public String getRedisAddress() {
return redisAddress;
}
public String getRedisIp() {
return redisIp;
}
public Integer getRedisPort() {
return redisPort;
}
public void setJobId(JobId jobId) {
this.jobId = jobId;
}
@@ -298,11 +233,7 @@ public class RayConfig {
}
public void setSessionDir(String sessionDir) {
this.sessionDir = sessionDir;
}
public String getSessionDir() {
return sessionDir;
updateSessionDir(sessionDir);
}
public Config getInternalConfig() {
@@ -312,7 +243,8 @@ public class RayConfig {
/**
* Renders the config value as a HOCON string.
*/
public String render() {
@Override
public String toString() {
// These items might be dynamically generated or mutated at runtime.
// Explicitly include them.
Map<String, Object> dynamic = new HashMap<>();
@@ -321,24 +253,19 @@ public class RayConfig {
dynamic.put("ray.object-store.socket-name", objectStoreSocketName);
dynamic.put("ray.raylet.node-manager-port", nodeManagerPort);
dynamic.put("ray.address", redisAddress);
dynamic.put("ray.job.code-search-path", codeSearchPath);
Config toRender = ConfigFactory.parseMap(dynamic).withFallback(config);
return toRender.root().render(ConfigRenderOptions.concise());
}
private void updateSessionDir() {
private void updateSessionDir(String sessionDir) {
// session dir
if (workerMode == WorkerType.DRIVER) {
final int minBound = 100000;
final int maxBound = 999999;
final String sessionName = String.format("session_%s_%d", DATE_TIME_FORMATTER.format(
LocalDateTime.now()), RANDOM.nextInt(maxBound - minBound) + minBound);
sessionDir = String.format("%s/%s", DEFAULT_TEMP_DIR, sessionName);
} else if (workerMode == WorkerType.WORKER) {
sessionDir = removeTrailingSlash(config.getString("ray.session-dir"));
} else {
throw new RuntimeException("Unknown worker type.");
if (config.hasPath("ray.session-dir")) {
sessionDir = config.getString("ray.session-dir");
}
if (sessionDir != null) {
sessionDir = removeTrailingSlash(sessionDir);
}
this.sessionDir = sessionDir;
// Log dir.
String localLogDir = null;
@@ -350,34 +277,6 @@ public class RayConfig {
} else {
logDir = localLogDir;
}
// Object store socket name.
String localObjectStoreSocketName = null;
if (config.hasPath("ray.object-store.socket-name")) {
localObjectStoreSocketName = config.getString("ray.object-store.socket-name");
}
if (Strings.isNullOrEmpty(localObjectStoreSocketName)) {
objectStoreSocketName = String.format("%s/sockets/object_store", sessionDir);
} else {
objectStoreSocketName = localObjectStoreSocketName;
}
// Raylet socket name.
String localRayletSocketName = null;
if (config.hasPath("ray.raylet.socket-name")) {
localRayletSocketName = config.getString("ray.raylet.socket-name");
}
if (Strings.isNullOrEmpty(localRayletSocketName)) {
rayletSocketName = String.format("%s/sockets/raylet", sessionDir);
} else {
rayletSocketName = localRayletSocketName;
}
}
@Override
public String toString() {
return render();
}
/**
@@ -43,16 +43,6 @@ public class RuntimeContextImpl implements RuntimeContext {
return runtime.getGcsClient().wasCurrentActorRestarted(getCurrentActorId());
}
@Override
public String getRayletSocketName() {
return runtime.getRayConfig().rayletSocketName;
}
@Override
public String getObjectStoreSocketName() {
return runtime.getRayConfig().objectStoreSocketName;
}
@Override
public boolean isSingleProcess() {
return RunMode.SINGLE_PROCESS == runtime.getRayConfig().runMode;
@@ -1,5 +1,6 @@
package io.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import io.ray.runtime.config.RayConfig;
/**
@@ -11,8 +12,10 @@ public class GcsClientOptions {
public String password;
public GcsClientOptions(RayConfig rayConfig) {
ip = rayConfig.getRedisIp();
port = rayConfig.getRedisPort();
String[] ipAndPort = rayConfig.getRedisAddress().split(":");
Preconditions.checkArgument(ipAndPort.length == 2, "Invalid redis address.");
ip = ipAndPort[0];
port = Integer.parseInt(ipAndPort[1]);
password = rayConfig.redisPassword;
}
}
@@ -1,31 +1,21 @@
package io.ray.runtime.runner;
import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.ResourceUtil;
import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
/**
* Ray service management on one box.
@@ -34,97 +24,78 @@ public class RunManager {
private static final Logger LOGGER = LoggerFactory.getLogger(RunManager.class);
private static final String WORKER_CLASS = "io.ray.runtime.runner.worker.DefaultWorker";
private static final String SESSION_LATEST = "session_latest";
private RayConfig rayConfig;
private List<Pair<String, Process>> processes;
private static final int KILL_PROCESS_WAIT_TIMEOUT_SECONDS = 1;
public RunManager(RayConfig rayConfig) {
this.rayConfig = rayConfig;
processes = new ArrayList<>();
createTempDirs();
}
public void cleanup() {
// Terminate the processes in the reversed order of creating them.
// Because raylet needs to exit before object store, otherwise it
// cannot exit gracefully.
for (int i = processes.size() - 1; i >= 0; --i) {
Pair<String, Process> pair = processes.get(i);
terminateProcess(pair.getLeft(), pair.getRight());
}
}
public void terminateProcess(String name, Process p) {
int numAttempts = 0;
while (p.isAlive()) {
if (numAttempts == 0) {
LOGGER.debug("Terminating process {}.", name);
p.destroy();
} else {
LOGGER.debug("Terminating process {} forcibly.", name);
p.destroyForcibly();
}
try {
p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
} catch (InterruptedException e) {
LOGGER.warn("Got InterruptedException while waiting for process {}" +
" to be terminated.", name);
}
numAttempts++;
}
LOGGER.debug("Process {} is now terminated.", name);
}
private static final Pattern pattern = Pattern.compile("--address='([^']+)'");
/**
* Get processes by name. For test purposes only.
* Start the head node.
*/
public List<Process> getProcesses(String name) {
return processes.stream().filter(pair -> pair.getLeft().equals(name)).map(Pair::getRight)
.collect(Collectors.toList());
}
private void createTempDirs() {
public static void startRayHead(RayConfig rayConfig) {
LOGGER.debug("Starting ray runtime @ {}.", rayConfig.nodeIp);
String codeSearchPath;
if (!rayConfig.codeSearchPath.isEmpty()) {
codeSearchPath = Joiner.on(File.pathSeparator).join(rayConfig.codeSearchPath);
} else {
codeSearchPath = System.getProperty("java.class.path");
}
List<String> command = new ArrayList<>();
command.add("ray");
command.add("start");
command.add("--head");
command.add("--redis-password");
command.add(rayConfig.redisPassword);
command.add("--system-config=" + new Gson().toJson(rayConfig.rayletConfigParameters));
command.add("--code-search-path=" + codeSearchPath);
command.addAll(rayConfig.headArgs);
String output;
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
FileUtils.forceMkdir(new File(rayConfig.rayletSocketName).getParentFile());
FileUtils.forceMkdir(new File(rayConfig.objectStoreSocketName).getParentFile());
// Remove session_latest first, and then create a new symbolic link for session_latest.
final String parentOfSessionDir = new File(rayConfig.sessionDir).getParent();
final File sessionLatest = new File(
String.format("%s/%s", parentOfSessionDir, SESSION_LATEST));
if (sessionLatest.exists()) {
sessionLatest.delete();
}
Files.createSymbolicLink(
Paths.get(sessionLatest.getAbsolutePath()),
Paths.get(rayConfig.sessionDir));
} catch (IOException e) {
LOGGER.error("Couldn't create temp directories.", e);
throw new RuntimeException(e);
output = runCommand(command);
} catch (Exception e) {
throw new RuntimeException("Failed to start Ray runtime.", e);
}
Matcher matcher = pattern.matcher(output);
if (matcher.find()) {
String redisAddress = matcher.group(1);
rayConfig.setRedisAddress(redisAddress);
} else {
throw new RuntimeException("Redis address is not found. output: " + output);
}
LOGGER.info("Ray runtime started @ {}.", rayConfig.nodeIp);
}
/**
* @return Log files for stdout and stderr.
* Stop ray.
*/
private Pair<File, File> getLogFiles(String logDir, String processName) {
int suffixIndex = 0;
while (true) {
String suffix = suffixIndex == 0 ? "" : "." + suffixIndex;
File stdout = new File(String.format("%s/%s%s.out", logDir, suffix, processName));
File stderr = new File(String.format("%s/%s%s.err", logDir, suffix, processName));
if (!stdout.exists() && !stderr.exists()) {
return ImmutablePair.of(stdout, stderr);
}
suffixIndex += 1;
public static void stopRay() {
List<String> command = new ArrayList<>();
command.add("ray");
command.add("stop");
command.add("--force");
try {
runCommand(command);
} catch (Exception e) {
throw new RuntimeException("Failed to stop ray.", e);
}
}
public static void getAddressInfoAndFillConfig(RayConfig rayConfig) {
// NOTE(kfstorm): This method depends on an internal Python API of ray to get the
// address info of the local node.
String script = String.format("import ray;"
+ " print(ray._private.services.get_address_info_from_redis("
+ "'%s', '%s', redis_password='%s', no_warning=True))",
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
List<String> command = Arrays.asList("python", "-c", script);
String output = null;
try {
output = runCommand(command);
JsonObject addressInfo = new JsonParser().parse(output).getAsJsonObject();
rayConfig.rayletSocketName = addressInfo.get("raylet_socket_name").getAsString();
rayConfig.objectStoreSocketName = addressInfo.get("object_store_address").getAsString();
rayConfig.nodeManagerPort = addressInfo.get("node_manager_port").getAsInt();
} catch (Exception e) {
throw new RuntimeException("Failed to get address info. Output: " + output, e);
}
}
@@ -132,284 +103,22 @@ public class RunManager {
* Start a process.
*
* @param command The command to start the process with.
* @param env Environment variables.
* @param name Process name.
*/
private void startProcess(List<String> command, Map<String, String> env, String name) {
private static String runCommand(List<String> command) throws IOException, InterruptedException {
if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Starting process {} with command: {}", name,
Joiner.on(" ").join(command));
LOGGER.debug("Starting process with command: {}", Joiner.on(" ").join(command));
}
ProcessBuilder builder = new ProcessBuilder(command);
String stdout = "";
String stderr = "";
// Set stdout and stderr paths.
Pair<File, File> logFiles = getLogFiles(rayConfig.logDir, name);
builder.redirectOutput(logFiles.getLeft());
builder.redirectError(logFiles.getRight());
// Set environment variables.
if (env != null && !env.isEmpty()) {
builder.environment().putAll(env);
}
Process p;
try {
p = builder.start();
} catch (IOException e) {
LOGGER.error("Failed to start process " + name, e);
throw new RuntimeException("Failed to start process " + name, e);
}
// Wait 1000 ms and check whether the process is alive.
try {
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
if (!p.isAlive()) {
String message = String.format("Failed to start %s. Exit code: %d.",
name, p.exitValue());
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
throw new RuntimeException(message);
}
processes.add(Pair.of(name, p));
if (LOGGER.isDebugEnabled()) {
String message = String.format("%s process started.", name);
message += String.format(" Logs are redirected to %s and %s.", stdout, stderr);
LOGGER.debug(message);
ProcessBuilder builder = new ProcessBuilder(command).redirectErrorStream(true);
Process p = builder.start();
String output = IOUtils.toString(p.getInputStream(), Charset.defaultCharset());
p.waitFor();
if (p.exitValue() != 0) {
String sb = "The exit value of the process is " + p.exitValue()
+ ". Command: " + Joiner.on(" ").join(command) + "\n"
+ "output:\n" + output;
throw new RuntimeException(sb);
}
return output;
}
/**
* Start all Ray processes on this node.
*
* @param isHead Whether this node is the head node. If true, redis server will be started.
*/
public void startRayProcesses(boolean isHead) {
LOGGER.debug("Starting ray runtime @ {}.", rayConfig.nodeIp);
try {
if (isHead) {
startGcs();
}
startObjectStore();
startRaylet(isHead);
LOGGER.info("Ray runtime started @ {}.", rayConfig.nodeIp);
} catch (Exception e) {
// Clean up started processes.
cleanup();
LOGGER.error("Failed to start ray runtime.", e);
throw new RuntimeException("Failed to start ray runtime.", e);
}
}
private void startGcs() {
// start primary redis
String primary = startRedisInstance(rayConfig.nodeIp,
rayConfig.headRedisPort, rayConfig.headRedisPassword, null);
rayConfig.setRedisAddress(primary);
try (Jedis client = new Jedis("127.0.0.1", rayConfig.headRedisPort)) {
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
client.auth(rayConfig.headRedisPassword);
}
client.set("UseRaylet", "1");
// Set job counter to compute job id.
client.set("JobCounter", "0");
// Register the number of Redis shards in the primary shard, so that clients
// know how many redis shards to expect under RedisShards.
client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards));
// Set session dir for this cluster, so that the drivers which connected to this
// cluster will fetch this session dir as its self's session dir.
client.set("session_dir", rayConfig.getSessionDir());
// start redis shards
for (int i = 0; i < rayConfig.numberRedisShards; i++) {
String shard = startRedisInstance(rayConfig.nodeIp,
rayConfig.redisShardPorts[i], rayConfig.headRedisPassword, i);
client.rpush("RedisShards", shard);
}
}
// start gcs server
String redisPasswordOption = "";
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
redisPasswordOption = rayConfig.headRedisPassword;
}
// See `src/ray/gcs/gcs_server/gcs_server_main.cc` for the meaning of each parameter.
final File gcsServerFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.GCS_SERVER_BINARY_NAME);
Preconditions.checkState(gcsServerFile.setExecutable(true));
List<String> command = ImmutableList.of(
gcsServerFile.getAbsolutePath(),
String.format("--redis_address=%s", rayConfig.getRedisIp()),
String.format("--redis_port=%d", rayConfig.getRedisPort()),
String.format("--config_list=%s",
rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue()).collect(Collectors
.joining(","))),
String.format("--redis_password=%s", redisPasswordOption)
);
startProcess(command, null, "gcs_server");
}
private String startRedisInstance(String ip, int port, String password, Integer shard) {
final File redisServerFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.REDIS_SERVER_BINARY_NAME);
Preconditions.checkState(redisServerFile.setExecutable(true));
// The redis module file.
File redisModule = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.REDIS_MODULE_LIBRARY_NAME);
Preconditions.checkState(redisModule.setExecutable(true));
List<String> command = Lists.newArrayList(
// The redis-server executable file.
redisServerFile.getAbsolutePath(),
"--protected-mode",
"no",
"--port",
String.valueOf(port),
"--loglevel",
"warning",
"--loadmodule",
// The redis module file.
redisModule.getAbsolutePath()
);
if (!Strings.isNullOrEmpty(password)) {
command.add("--requirepass ");
command.add(password);
}
String name = shard == null ? "redis" : "redis-shard_" + shard;
startProcess(command, null, name);
try (Jedis client = new Jedis("127.0.0.1", port)) {
if (!Strings.isNullOrEmpty(password)) {
client.auth(password);
}
// Configure Redis to only generate notifications for the export keys.
client.configSet("notify-keyspace-events", "Kl");
// Put a time stamp in Redis to indicate when it was started.
client.set("redis_start_time", LocalDateTime.now().toString());
}
return ip + ":" + port;
}
private void startRaylet(boolean isHead) throws IOException {
int hardwareConcurrency = Runtime.getRuntime().availableProcessors();
int maximumStartupConcurrency = Math.max(1,
Math.min(rayConfig.resources.getOrDefault("CPU", 0.0).intValue(), hardwareConcurrency));
String redisPasswordOption = "";
if (!Strings.isNullOrEmpty(rayConfig.headRedisPassword)) {
redisPasswordOption = rayConfig.headRedisPassword;
}
// See `src/ray/raylet/main.cc` for the meaning of each parameter.
final File rayletFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.RAYLET_BINARY_NAME);
Preconditions.checkState(rayletFile.setExecutable(true));
List<String> command = ImmutableList.of(
rayletFile.getAbsolutePath(),
String.format("--raylet_socket_name=%s", rayConfig.rayletSocketName),
String.format("--store_socket_name=%s", rayConfig.objectStoreSocketName),
String.format("--object_manager_port=%d", 0), // The object manager port.
// The node manager port.
String.format("--node_manager_port=%d", rayConfig.getNodeManagerPort()),
String.format("--node_ip_address=%s", rayConfig.nodeIp),
String.format("--redis_address=%s", rayConfig.getRedisIp()),
String.format("--redis_port=%d", rayConfig.getRedisPort()),
String.format("--num_initial_workers=%d", 0), // number of initial workers
String.format("--maximum_startup_concurrency=%d", maximumStartupConcurrency),
String.format("--static_resource_list=%s",
ResourceUtil.getResourcesStringFromMap(rayConfig.resources)),
String.format("--config_list=%s", rayConfig.rayletConfigParameters.entrySet().stream()
.map(entry -> entry.getKey() + "," + entry.getValue())
.collect(Collectors.joining(","))),
String.format("--python_worker_command=%s", buildPythonWorkerCommand()),
String.format("--java_worker_command=%s", buildWorkerCommand()),
String.format("--redis_password=%s", redisPasswordOption),
isHead ? "--head_node" : ""
);
startProcess(command, null, "raylet");
}
private String concatPath(Stream<String> stream) {
// TODO (hchen): Right now, raylet backend doesn't support worker command with spaces.
// Thus, we have to drop some some paths until that is fixed.
return stream.filter(s -> !s.contains(" ")).collect(Collectors.joining(":"));
}
private String buildWorkerCommand() throws IOException {
List<String> cmd = new ArrayList<>();
cmd.add("java");
cmd.add("-classpath");
// Generate classpath based on current classpath + user-defined classpath.
String classpath = concatPath(Stream.concat(
rayConfig.classpath.stream(),
Stream.of(System.getProperty("java.class.path").split(":"))
));
cmd.add(classpath);
// Write current config to a file, and set the file path as Java worker's config file.
// This allows users to set worker config by setting driver's system properties.
File workerConfigFile = new File(rayConfig.sessionDir + "/java_worker.conf");
FileUtils.write(workerConfigFile, rayConfig.render(), Charset.defaultCharset());
cmd.add("-Dray.config-file=" + workerConfigFile.getAbsolutePath());
if (!rayConfig.codeSearchPath.isEmpty()) {
cmd.add("-Dray.job.code-search-path=" +
String.join(":", rayConfig.codeSearchPath));
}
cmd.add("RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER");
cmd.addAll(rayConfig.jvmParameters);
// jvm options
cmd.add("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER");
// Main class
cmd.add(WORKER_CLASS);
String command = Joiner.on(" ").join(cmd);
LOGGER.debug("Worker command is: {}", command);
return command;
}
private void startObjectStore() {
final File objectStoreFile = BinaryFileUtil.getNativeFile(
rayConfig.sessionDir, BinaryFileUtil.PLASMA_STORE_SERVER_BINARY_NAME);
Preconditions.checkState(objectStoreFile.setExecutable(true));
List<String> command = ImmutableList.of(
// The plasma store executable file.
objectStoreFile.getAbsolutePath(),
"-s",
rayConfig.objectStoreSocketName,
"-m",
rayConfig.objectStoreSize.toString()
);
startProcess(command, null, "plasma_store");
}
private String buildPythonWorkerCommand() {
// disable python worker start from raylet, which starts from java
if (rayConfig.pythonWorkerCommand == null) {
return "";
}
List<String> cmd = new ArrayList<>();
cmd.add(rayConfig.pythonWorkerCommand);
cmd.add("--node-ip-address=" + rayConfig.nodeIp);
cmd.add("--object-store-name=" + rayConfig.objectStoreSocketName);
cmd.add("--raylet-name=" + rayConfig.rayletSocketName);
cmd.add("--address=" + rayConfig.getRedisAddress());
String command = cmd.stream().collect(Collectors.joining(" "));
LOGGER.debug("python worker command: {}", command);
return command;
}
}
@@ -12,15 +12,6 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.SystemUtils;
public class BinaryFileUtil {
public static final String REDIS_SERVER_BINARY_NAME = "redis-server";
public static final String GCS_SERVER_BINARY_NAME = "gcs_server";
public static final String PLASMA_STORE_SERVER_BINARY_NAME = "plasma_store_server";
public static final String RAYLET_BINARY_NAME = "raylet";
public static final String REDIS_MODULE_LIBRARY_NAME = "libray_redis_module.so";
public static final String CORE_WORKER_JAVA_LIBRARY = "core_worker_library_java";
@@ -2,8 +2,9 @@ package io.ray.runtime.util;
import com.google.common.collect.Sets;
import com.sun.jna.NativeLibrary;
import io.ray.runtime.config.RayConfig;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -11,17 +12,7 @@ import org.slf4j.LoggerFactory;
public class JniUtils {
private static final Logger LOGGER = LoggerFactory.getLogger(JniUtils.class);
private static Set<String> loadedLibs = Sets.newHashSet();
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String libraryName) {
loadLibrary(libraryName, false);
}
private static String defaultDestDir;
/**
* Loads the native library specified by the <code>libraryName</code> argument.
@@ -29,15 +20,51 @@ public class JniUtils {
* prefix, file extension or path.
*
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String libraryName) {
loadLibrary(getDefaultDestDir(), libraryName);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
*/
public static synchronized void loadLibrary(String libraryName, boolean exportSymbols) {
loadLibrary(getDefaultDestDir(), libraryName, exportSymbols);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String destDir, String libraryName) {
loadLibrary(destDir, libraryName, false);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
*/
public static synchronized void loadLibrary(String destDir, String libraryName,
boolean exportSymbols) {
if (!loadedLibs.contains(libraryName)) {
LOGGER.debug("Loading native library {}.", libraryName);
// Load native library.
String fileName = System.mapLibraryName(libraryName);
final String sessionDir = RayConfig.getInstance().sessionDir;
final File file = BinaryFileUtil.getNativeFile(sessionDir, fileName);
final File file = BinaryFileUtil.getNativeFile(destDir, fileName);
if (exportSymbols) {
// Expose library symbols using RTLD_GLOBAL which may be depended by other shared
@@ -50,4 +77,17 @@ public class JniUtils {
}
}
/**
* Cache the result so that multiple calls return the same dest dir.
*/
private static synchronized String getDefaultDestDir() {
if (defaultDestDir == null) {
try {
defaultDestDir = Files.createTempDirectory("native_libs").toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return defaultDestDir;
}
}
@@ -18,9 +18,6 @@ ray {
// `CLUSTER`: Ray is running on one or more nodes, with multiple processes.
run-mode: CLUSTER
// Available resources on this node, for example "CPU:4,GPU:0".
resources: ""
// Configuration items about job.
job {
// If worker.mode is DRIVER, specify the job id.
@@ -56,40 +53,12 @@ ray {
max-backup-files: 10
}
// Custom worker jvm parameters.
worker.jvm-parameters: []
// Custom `java.library.path`
// Note, do not use `dir1:dir2` format, put each dir as a list item.
library.path: []
// Custom classpath.
// Note, do not use `dir1:dir2` format, put each dir as a list item.
classpath = []
// ----------------------
// Redis configurations
// ----------------------
redis {
// If `redis.server` isn't provided, which port we should use to start redis server.
// If `head-port` is not provided, it will be generated randomly.
// head-port: 6379
// Below passwords should be consistent with the one defined in python/ray/ray_constants.py.
// The password used to start the redis server on the head node.
head-password: "5241590000000000"
// The password used to connect to the redis server.
password: "5241590000000000"
// If `redis.server` isn't provided, how many Redis shards we should start in addition to the
// primary Redis shard. The ports of these shards will be `head-port + 1`, `head-port + 2`, etc.
shard-number: 1
}
// ----------------------------
// Object store configurations
// ----------------------------
object-store {
// Initial size of the object store.
size: 10 MB
}
// ----------------------------
@@ -97,12 +66,14 @@ ray {
// ----------------------------
raylet {
// See src/ray/ray_config_def.h for options.
// Below section takes effect only if Ray head is started by a driver.
config {
// TODO(zhuohan): enable this for java
put_small_object_in_memory_store: false
}
}
// Whether we enable job manager to submit and manage job.
enable-job-manager: false
// Below args will be appended as parameters of the `ray start` command.
// It takes effect only if Ray head is started by a driver.
head-args: []
}
@@ -2,6 +2,8 @@ package io.ray.runtime.config;
import io.ray.runtime.generated.Common.WorkerType;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -11,37 +13,39 @@ public class RayConfigTest {
@Test
public void testCreateRayConfig() {
Map<String, String> rayletConfig = new HashMap<>();
rayletConfig.put("one", "1");
rayletConfig.put("zero", "0");
rayletConfig.put("positive-integer", "123");
rayletConfig.put("negative-integer", "-123");
rayletConfig.put("float", "-123.456");
rayletConfig.put("true", "true");
rayletConfig.put("false", "false");
rayletConfig.put("string", "abc");
try {
System.setProperty("ray.job.code-search-path", "path/to/ray/job/resource/path");
for (Map.Entry<String, String> entry : rayletConfig.entrySet()) {
System.setProperty("ray.raylet.config." + entry.getKey(), entry.getValue());
}
RayConfig rayConfig = RayConfig.create();
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals(Collections.singletonList("path/to/ray/job/resource/path"),
rayConfig.codeSearchPath);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("one"), 1);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("zero"), 0);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("positive-integer"), 123);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("negative-integer"), -123);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("float"), -123.456f);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("true"), true);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("false"), false);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("string"), "abc");
} finally {
// Unset system properties.
System.clearProperty("ray.job.code-search-path");
}
}
@Test
public void testGenerateHeadPortRandomly() {
boolean isSame = true;
final int port1 = RayConfig.create().headRedisPort;
// If we the 2 ports are the same, let's retry.
// This is used to avoid any flaky chance.
for (int i = 0; i < NUM_RETRIES; ++i) {
final int port2 = RayConfig.create().headRedisPort;
if (port1 != port2) {
isSame = false;
break;
for (String key : rayletConfig.keySet()) {
System.clearProperty("ray.raylet.config." + key);
}
}
Assert.assertFalse(isSame);
}
@Test
public void testSpecifyHeadPort() {
System.setProperty("ray.redis.head-port", "11111");
Assert.assertEquals(RayConfig.create().headRedisPort, 11111);
}
}
+2 -2
View File
@@ -44,8 +44,8 @@ TEST_ARGS=(-Dray.raylet.config.num_workers_per_process_java=10 -Dray.job.num-jav
echo "Running tests under cluster mode."
# TODO(hchen): Ideally, we should use the following bazel command to run Java tests. However, if there're skipped tests,
# TestNG will exit with code 2. And bazel treats it as test failure.
# bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --config=ci || cluster_exit_code=$?
ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
# bazel test //java:all_tests --config=ci || cluster_exit_code=$?
run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
echo "Running tests under single-process mode."
# bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --config=ci || single_exit_code=$?
@@ -1,127 +0,0 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.lang.ProcessBuilder.Redirect;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
@Test(groups = {"cluster", "multiLanguage"})
public abstract class BaseMultiLanguageTest {
private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
Map<String, String> env) {
try {
LOGGER.info("Executing command: {}", String.join(" ", command));
ProcessBuilder processBuilder = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT)
.redirectError(Redirect.INHERIT);
for (Entry<String, String> entry : env.entrySet()) {
processBuilder.environment().put(entry.getKey(), entry.getValue());
}
Process process = processBuilder.start();
process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
@BeforeClass(alwaysRun = true, inheritGroups = false)
public void setUp() {
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
file.delete();
}
}
String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
// jars in the `ray` wheel doesn't contains test classes, so we add test classes explicitly.
// Since mvn test classes contains `test` in path and bazel test classes is located at a jar
// with `test` included in the name, we can check classpath `test` to filter out test classes.
List<String> classpath = Stream.of(System.getProperty("java.class.path").split(":"))
.filter(s -> !s.contains(" ") && s.contains("test"))
.collect(Collectors.toList());
// Start ray cluster.
List<String> startCommand = Arrays.asList(
"ray",
"start",
"--head",
"--port=6379",
"--min-worker-port=0",
"--max-worker-port=0",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
String.format("--node-manager-port=%s", nodeManagerPort),
"--load-code-from-local",
"--system-config=" + new Gson().toJson(RayConfig.create().rayletConfigParameters),
"--code-search-path=" + String.join(":", classpath)
);
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
Assert.assertFalse(Ray.isInitialized());
System.setProperty("ray.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
Ray.init();
}
/**
* @return The environment variables needed for the `ray start` command.
*/
protected Map<String, String> getRayStartEnv() {
return ImmutableMap.of();
}
@AfterClass(alwaysRun = true, inheritGroups = false)
public void tearDown() {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.raylet.node-manager-port");
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
}
@@ -1,46 +1,22 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import io.ray.api.Ray;
import java.io.File;
import java.lang.reflect.Method;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
public class BaseTest {
private static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class);
private List<File> filesToDelete = ImmutableList.of();
@BeforeMethod(alwaysRun = true)
public void setUpBase(Method method) {
Assert.assertFalse(Ray.isInitialized());
Ray.init();
// These files need to be deleted after each test case.
filesToDelete = ImmutableList.of(
new File(Ray.getRuntimeContext().getRayletSocketName()),
new File(Ray.getRuntimeContext().getObjectStoreSocketName()),
// TODO(pcm): This is a workaround for the issue described
// in the PR description of https://github.com/ray-project/ray/pull/5450
// and should be fixed properly.
new File("/tmp/ray/test/raylet_socket")
);
// Make sure the files will be deleted even if the test doesn't exit gracefully.
filesToDelete.forEach(File::deleteOnExit);
}
@AfterMethod(alwaysRun = true)
public void tearDownBase() {
Ray.shutdown();
for (File file : filesToDelete) {
file.delete();
}
}
}
@@ -1,7 +1,6 @@
package io.ray.test;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
@@ -19,17 +18,19 @@ import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class CrossLanguageInvocationTest extends BaseTest {
private static final String PYTHON_MODULE = "test_cross_language_invocation";
@Override
protected Map<String, String> getRayStartEnv() {
@BeforeClass
public void beforeClass() {
// Delete and re-create the temp dir.
File tempDir = new File(
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
@@ -48,7 +49,14 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
throw new RuntimeException(e);
}
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
System.setProperty("ray.job.code-search-path",
System.getProperty("java.class.path") + File.pathSeparator
+ tempDir.getAbsolutePath());
}
@AfterClass
public void afterClass() {
System.clearProperty("ray.job.code-search-path");
}
@Test
@@ -16,12 +16,12 @@ public class GcsClientTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "A:8");
System.setProperty("ray.head-args.0", "--resources={\"A\":8}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
}
public void testGetAllNodeInfo() {
@@ -17,12 +17,12 @@ public class GlobalGcTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "140 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 140L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
public static class LargeObjectWithCyclicRef {
@@ -5,7 +5,8 @@ import io.ray.api.Ray;
import org.testng.Assert;
import org.testng.annotations.Test;
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class MultiLanguageClusterTest extends BaseTest {
public static String echo(String word) {
return word;
@@ -18,9 +18,6 @@ public class RayAlterSuiteListener implements IAlterSuiteListener {
XmlGroups groups = new XmlGroups();
XmlRun run = new XmlRun();
run.onExclude(excludedGroup);
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
run.onExclude("multiLanguage");
}
groups.setRun(run);
suite.setGroups(groups);
}
@@ -8,7 +8,8 @@ import io.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
import org.testng.annotations.Test;
public class RaySerializerTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class RaySerializerTest extends BaseTest {
@Test
public void testSerializePyActor() {
@@ -25,7 +25,9 @@ public class RayletConfigTest extends BaseTest {
public static class TestActor {
public String getConfigValue() {
return TestUtils.getRuntime().getRayConfig().rayletConfigParameters.get(RAY_CONFIG_KEY);
return TestUtils.getRuntime().getRayConfig()
.rayletConfigParameters.get(RAY_CONFIG_KEY)
.toString();
}
}
@@ -11,13 +11,11 @@ public class RedisPasswordTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.redis.head-password", "12345678");
System.setProperty("ray.redis.password", "12345678");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.redis.head-password");
System.clearProperty("ray.redis.password");
}
@@ -27,12 +27,12 @@ import org.testng.annotations.Test;
public class ReferenceCountingTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "100 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 100L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
/**
@@ -20,12 +20,14 @@ public class ResourcesManagementTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
System.clearProperty("ray.head-args.1");
}
public static Integer echo(Integer number) {
@@ -14,8 +14,6 @@ import org.testng.annotations.Test;
public class RuntimeContextTest extends BaseTest {
private static JobId JOB_ID = getJobId();
private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket";
private static JobId getJobId() {
// Must be stable across different processes.
@@ -27,23 +25,16 @@ public class RuntimeContextTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.job.id", JOB_ID.toString());
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.job.id");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.object-store.socket-name");
}
@Test
public void testRuntimeContextInDriver() {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
}
public static class RuntimeContextTester {
@@ -51,9 +42,6 @@ public class RuntimeContextTest extends BaseTest {
public String testRuntimeContext(ActorId actorId) {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
return "ok";
}
}
+7 -5
View File
@@ -279,7 +279,8 @@ def get_address_info_from_redis_helper(redis_address,
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
redis_password=None):
redis_password=None,
no_warning=False):
counter = 0
while True:
try:
@@ -290,10 +291,11 @@ def get_address_info_from_redis(redis_address,
raise
# Some of the information may not be in Redis yet, so wait a little
# bit.
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
if not no_warning:
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
time.sleep(1)
counter += 1
+2
View File
@@ -537,6 +537,8 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports,
with cli_logger.group("Next steps"):
cli_logger.print(
"To connect to this Ray runtime from another node, run")
# NOTE(kfstorm): Java driver rely on this line to get the address
# of the cluster. Please be careful when updating this line.
cli_logger.print(
cf.bold(" ray start --address='{}'{}"), redis_address,
f" --redis-password='{redis_password}'"
@@ -25,7 +25,7 @@ extern "C" {
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;[B)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
@@ -68,7 +68,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_object_NativeObjectStore_nativeDelete
/*
* Class: io_ray_runtime_object_NativeObjectStore
* Method: nativeAddLocalReference
* Signature: ([B)V
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL
Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, jclass,
@@ -78,7 +78,7 @@ Java_io_ray_runtime_object_NativeObjectStore_nativeAddLocalReference(JNIEnv *, j
/*
* Class: io_ray_runtime_object_NativeObjectStore
* Method: nativeRemoveLocalReference
* Signature: ([B)V
* Signature: ([B[B)V
*/
JNIEXPORT void JNICALL
Java_io_ray_runtime_object_NativeObjectStore_nativeRemoveLocalReference(JNIEnv *, jclass,
@@ -1,30 +1,16 @@
package io.ray.streaming.api.context;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class ClusterStarter {
private static final Logger LOG = LoggerFactory.getLogger(ClusterStarter.class);
private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/plasma_store_socket";
private static final String RAYLET_SOCKET_NAME = "/tmp/ray/raylet_socket";
static synchronized void startCluster(boolean isCrossLanguage, boolean isLocal) {
static synchronized void startCluster(boolean isLocal) {
Preconditions.checkArgument(!Ray.isInitialized());
RayConfig.reset();
if (!isLocal) {
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
@@ -33,97 +19,13 @@ class ClusterStarter {
System.setProperty("ray.run-mode", "SINGLE_PROCESS");
}
if (!isCrossLanguage) {
Ray.init();
return;
}
// Delete existing socket files.
for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
File file = new File(socket);
if (file.exists()) {
LOG.info("Delete existing socket file {}", file);
file.delete();
}
}
String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());
// jars in the `ray` wheel doesn't contains test classes, so we add test classes explicitly.
// Since mvn test classes contains `test` in path and bazel test classes is located at a jar
// with `test` included in the name, we can check classpath `test` to filter out test classes.
String classpath = Stream.of(System.getProperty("java.class.path").split(":"))
.filter(s -> !s.contains(" ") && s.contains("test"))
.collect(Collectors.joining(":"));
String workerOptions = new Gson().toJson(ImmutableList.of("-classpath", classpath));
Map<String, String> config = new HashMap<>(RayConfig.create().rayletConfigParameters);
config.put("num_workers_per_process_java", "1");
// Start ray cluster.
List<String> startCommand = ImmutableList.of(
"ray",
"start",
"--head",
"--port=6379",
String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
String.format("--node-manager-port=%s", nodeManagerPort),
"--load-code-from-local",
"--java-worker-options=" + workerOptions,
"--system-config=" + new Gson().toJson(config)
);
if (!executeCommand(startCommand, 10)) {
throw new RuntimeException("Couldn't start ray cluster.");
}
// Connect to the cluster.
System.setProperty("ray.address", "127.0.0.1:6379");
System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
Ray.init();
}
public static synchronized void stopCluster(boolean isCrossLanguage) {
public static synchronized void stopCluster() {
// Disconnect to the cluster.
Ray.shutdown();
System.clearProperty("ray.address");
System.clearProperty("ray.object-store.socket-name");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.raylet.node-manager-port");
System.clearProperty("ray.raylet.config.num_workers_per_process_java");
System.clearProperty("ray.run-mode");
if (isCrossLanguage) {
// Stop ray cluster.
final List<String> stopCommand = ImmutableList.of(
"ray",
"stop"
);
if (!executeCommand(stopCommand, 10)) {
throw new RuntimeException("Couldn't stop ray cluster");
}
}
}
/**
* Execute an external command.
*
* @return Whether the command succeeded.
*/
private static boolean executeCommand(List<String> command, int waitTimeoutSeconds) {
LOG.info("Executing command: {}", String.join(" ", command));
try {
ProcessBuilder processBuilder = new ProcessBuilder(command)
.redirectOutput(ProcessBuilder.Redirect.INHERIT)
.redirectError(ProcessBuilder.Redirect.INHERIT);
Process process = processBuilder.start();
boolean exit = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
if (!exit) {
process.destroyForcibly();
}
return process.exitValue() == 0;
} catch (Exception e) {
throw new RuntimeException("Error executing command " + String.join(" ", command), e);
}
}
}
@@ -65,11 +65,10 @@ public class StreamingContext implements Serializable {
if (!Ray.isInitialized()) {
if (Config.MEMORY_CHANNEL.equalsIgnoreCase(jobConfig.get(Config.CHANNEL_TYPE))) {
Preconditions.checkArgument(!jobGraph.isCrossLanguageGraph());
ClusterStarter.startCluster(false, true);
ClusterStarter.startCluster(true);
LOG.info("Created local cluster for job {}.", jobName);
} else {
ClusterStarter.startCluster(jobGraph.isCrossLanguageGraph(), false);
ClusterStarter.startCluster(false);
LOG.info("Created multi process cluster for job {}.", jobName);
}
Runtime.getRuntime().addShutdownHook(new Thread(StreamingContext.this::stop));
@@ -103,7 +102,7 @@ public class StreamingContext implements Serializable {
public void stop() {
if (Ray.isInitialized()) {
ClusterStarter.stopCluster(jobGraph.isCrossLanguageGraph());
ClusterStarter.stopCluster();
}
}
}
@@ -1,6 +1,5 @@
package io.ray.streaming.jobgraph;
import io.ray.streaming.api.Language;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@@ -138,14 +137,4 @@ public class JobGraph implements Serializable {
}
}
public boolean isCrossLanguageGraph() {
Language language = jobVertices.get(0).getLanguage();
for (JobVertex jobVertex : jobVertices) {
if (jobVertex.getLanguage() != language) {
return true;
}
}
return false;
}
}
@@ -1,6 +1,7 @@
package io.ray.streaming.runtime.transfer;
import io.ray.runtime.RayNativeRuntime;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.JniUtils;
/**
@@ -10,11 +11,7 @@ import io.ray.runtime.util.JniUtils;
public class TransferHandler {
static {
try {
Class.forName(RayNativeRuntime.class.getName());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
JniUtils.loadLibrary("streaming_java");
}
@@ -1,6 +1,7 @@
package io.ray.streaming.runtime.util;
import io.ray.runtime.RayNativeRuntime;
import io.ray.runtime.util.BinaryFileUtil;
import io.ray.runtime.util.JniUtils;
import java.lang.management.ManagementFactory;
import java.net.InetAddress;
@@ -29,13 +30,7 @@ public class EnvUtil {
}
public static void loadNativeLibraries() {
// Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java`
// is loaded before `streaming_java`.
try {
Class.forName(RayNativeRuntime.class.getName());
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
JniUtils.loadLibrary(BinaryFileUtil.CORE_WORKER_JAVA_LIBRARY, true);
JniUtils.loadLibrary("streaming_java");
}
@@ -58,11 +58,11 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
void beforeMethod() {
LOGGER.info("beforeTest");
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
System.setProperty("ray.redirect-output", "true");
RayConfig.reset();
Ray.init();
}
@@ -71,6 +71,8 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
LOGGER.info("afterTest");
Ray.shutdown();
System.clearProperty("ray.run-mode");
System.clearProperty("ray.head-args.0");
System.clearProperty("ray.head-args.1");
}
@Test(timeOut = 300000)
@@ -78,7 +80,8 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
LOGGER.info("StreamingQueueTest.testReaderWriter run-mode: {}",
System.getProperty("ray.run-mode"));
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");
@@ -134,7 +137,8 @@ public class StreamingQueueTest extends BaseUnitTest implements Serializable {
@Test(timeOut = 60000)
public void testWordCount() {
Ray.shutdown();
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
System.setProperty("ray.run-mode", "CLUSTER");