[Java] Simplify Ray.init() by invoking ray start internally (#10762)

This commit is contained in:
Kai Yang
2020-12-04 14:33:45 +08:00
committed by GitHub
parent 8cebe1e79c
commit 21fcee28f9
39 changed files with 367 additions and 1085 deletions
@@ -1,127 +0,0 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.gson.Gson;
import io.ray.api.Ray;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.util.NetworkUtil;
import java.io.File;
import java.lang.ProcessBuilder.Redirect;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
 * Base class for tests that run against a Ray cluster launched with the {@code ray start}
 * command line tool.
 *
 * <p>Before the test class runs, a head node is started and the driver connects to it. After the
 * class finishes, the driver disconnects and the cluster is stopped with {@code ray stop}.
 */
@Test(groups = {"cluster", "multiLanguage"})
public abstract class BaseMultiLanguageTest {

  private static final Logger LOGGER = LoggerFactory.getLogger(BaseMultiLanguageTest.class);

  private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket";

  private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";

  /**
   * Execute an external command and wait for it to finish.
   *
   * <p>The child's stdout and stderr are inherited from this process.
   *
   * @param command The command and its arguments.
   * @param waitTimeoutSeconds Maximum time to wait for the command to exit, in seconds.
   * @param env Extra environment variables to set for the child process.
   * @return Whether the command succeeded, i.e. exited with code 0.
   * @throws RuntimeException If the command cannot be started, does not exit within the timeout,
   *     or the waiting thread is interrupted.
   */
  private boolean executeCommand(List<String> command, int waitTimeoutSeconds,
      Map<String, String> env) {
    final String commandLine = String.join(" ", command);

    Process process;
    try {
      LOGGER.info("Executing command: {}", commandLine);
      ProcessBuilder processBuilder = new ProcessBuilder(command)
          .redirectOutput(Redirect.INHERIT)
          .redirectError(Redirect.INHERIT);
      for (Entry<String, String> entry : env.entrySet()) {
        processBuilder.environment().put(entry.getKey(), entry.getValue());
      }
      process = processBuilder.start();
    } catch (Exception e) {
      throw new RuntimeException("Error executing command " + commandLine, e);
    }

    boolean finished;
    try {
      finished = process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      // Do not leave the child running, and keep the interrupt visible to callers.
      process.destroyForcibly();
      Thread.currentThread().interrupt();
      throw new RuntimeException("Error executing command " + commandLine, e);
    }

    if (!finished) {
      // waitFor() returned false: the process is still alive, so exitValue() would throw
      // IllegalThreadStateException. Kill it and report the timeout explicitly instead.
      process.destroyForcibly();
      throw new RuntimeException("Error executing command " + commandLine
          + ": timed out after " + waitTimeoutSeconds + " seconds");
    }
    return process.exitValue() == 0;
  }

  /**
   * Starts a Ray head node with {@code ray start} and connects the driver to it.
   */
  @BeforeClass(alwaysRun = true, inheritGroups = false)
  public void setUp() {
    // Delete existing socket files.
    for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) {
      File file = new File(socket);
      if (file.exists()) {
        file.delete();
      }
    }

    String nodeManagerPort = String.valueOf(NetworkUtil.getUnusedPort());

    // Jars in the `ray` wheel don't contain test classes, so we add test classes explicitly.
    // Since mvn test classes contain `test` in their path and bazel test classes are located in
    // a jar with `test` in its name, we filter the classpath on `test` to pick up test classes.
    List<String> classpath = Stream.of(System.getProperty("java.class.path").split(":"))
        .filter(s -> !s.contains(" ") && s.contains("test"))
        .collect(Collectors.toList());

    // Start ray cluster.
    List<String> startCommand = Arrays.asList(
        "ray",
        "start",
        "--head",
        "--port=6379",
        "--min-worker-port=0",
        "--max-worker-port=0",
        String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME),
        String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME),
        String.format("--node-manager-port=%s", nodeManagerPort),
        "--load-code-from-local",
        "--system-config=" + new Gson().toJson(RayConfig.create().rayletConfigParameters),
        "--code-search-path=" + String.join(":", classpath)
    );
    if (!executeCommand(startCommand, 10, getRayStartEnv())) {
      throw new RuntimeException("Couldn't start ray cluster.");
    }

    // Connect to the cluster.
    Assert.assertFalse(Ray.isInitialized());
    System.setProperty("ray.address", "127.0.0.1:6379");
    System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME);
    System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
    System.setProperty("ray.raylet.node-manager-port", nodeManagerPort);
    Ray.init();
  }

  /**
   * Subclasses may override this to pass extra environment variables to {@code ray start}.
   *
   * @return The environment variables needed for the `ray start` command.
   */
  protected Map<String, String> getRayStartEnv() {
    return ImmutableMap.of();
  }

  /**
   * Disconnects the driver, clears the system properties set in {@link #setUp()} and stops the
   * cluster with {@code ray stop}.
   */
  @AfterClass(alwaysRun = true, inheritGroups = false)
  public void tearDown() {
    // Disconnect from the cluster.
    Ray.shutdown();
    System.clearProperty("ray.address");
    System.clearProperty("ray.object-store.socket-name");
    System.clearProperty("ray.raylet.socket-name");
    System.clearProperty("ray.raylet.node-manager-port");

    // Stop ray cluster.
    final List<String> stopCommand = ImmutableList.of(
        "ray",
        "stop"
    );
    if (!executeCommand(stopCommand, 10, ImmutableMap.of())) {
      throw new RuntimeException("Couldn't stop ray cluster");
    }
  }
}
@@ -1,46 +1,22 @@
package io.ray.test;
import com.google.common.collect.ImmutableList;
import io.ray.api.Ray;
import java.io.File;
import java.lang.reflect.Method;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
/**
 * Common base for tests: initializes Ray before every test method and shuts it down afterwards,
 * removing the socket files that the run left behind.
 */
public class BaseTest {

  private static final Logger LOGGER = LoggerFactory.getLogger(BaseTest.class);

  /** Files created by the current test method that must be removed once it finishes. */
  private List<File> filesToDelete = ImmutableList.of();

  /**
   * Initializes Ray for the upcoming test method and records the files to clean up afterwards.
   *
   * @param method The test method about to run.
   */
  @BeforeMethod(alwaysRun = true)
  public void setUpBase(Method method) {
    Assert.assertFalse(Ray.isInitialized());
    Ray.init();

    File rayletSocket = new File(Ray.getRuntimeContext().getRayletSocketName());
    File objectStoreSocket = new File(Ray.getRuntimeContext().getObjectStoreSocketName());
    // TODO(pcm): This is a workaround for the issue described
    // in the PR description of https://github.com/ray-project/ray/pull/5450
    // and should be fixed properly.
    File fixedRayletSocket = new File("/tmp/ray/test/raylet_socket");

    // These files need to be deleted after each test case.
    filesToDelete = ImmutableList.of(rayletSocket, objectStoreSocket, fixedRayletSocket);

    // Register the files for deletion at JVM exit too, in case the test doesn't exit gracefully.
    for (File file : filesToDelete) {
      file.deleteOnExit();
    }
  }

  /**
   * Shuts Ray down and deletes the files recorded by {@link #setUpBase(Method)}.
   */
  @AfterMethod(alwaysRun = true)
  public void tearDownBase() {
    Ray.shutdown();
    filesToDelete.forEach(File::delete);
  }
}
@@ -1,7 +1,6 @@
package io.ray.test;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import io.ray.api.ActorHandle;
import io.ray.api.ObjectRef;
import io.ray.api.PyActorHandle;
@@ -19,17 +18,19 @@ import java.io.InputStream;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.testng.Assert;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class CrossLanguageInvocationTest extends BaseTest {
private static final String PYTHON_MODULE = "test_cross_language_invocation";
@Override
protected Map<String, String> getRayStartEnv() {
@BeforeClass
public void beforeClass() {
// Delete and re-create the temp dir.
File tempDir = new File(
System.getProperty("java.io.tmpdir") + File.separator + "ray_cross_language_test");
@@ -48,7 +49,14 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest {
throw new RuntimeException(e);
}
return ImmutableMap.of("PYTHONPATH", tempDir.getAbsolutePath());
System.setProperty("ray.job.code-search-path",
System.getProperty("java.class.path") + File.pathSeparator
+ tempDir.getAbsolutePath());
}
/**
 * Removes the {@code ray.job.code-search-path} system property once this class is done.
 */
@AfterClass
public void afterClass() {
// Clear the property so the value does not stay set in this JVM's system properties.
System.clearProperty("ray.job.code-search-path");
}
@Test
@@ -16,12 +16,12 @@ public class GcsClientTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "A:8");
System.setProperty("ray.head-args.0", "--resources={\"A\":8}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
}
public void testGetAllNodeInfo() {
@@ -17,12 +17,12 @@ public class GlobalGcTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "140 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 140L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
public static class LargeObjectWithCyclicRef {
@@ -5,7 +5,8 @@ import io.ray.api.Ray;
import org.testng.Assert;
import org.testng.annotations.Test;
public class MultiLanguageClusterTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class MultiLanguageClusterTest extends BaseTest {
public static String echo(String word) {
return word;
@@ -18,9 +18,6 @@ public class RayAlterSuiteListener implements IAlterSuiteListener {
XmlGroups groups = new XmlGroups();
XmlRun run = new XmlRun();
run.onExclude(excludedGroup);
if (!"1".equals(System.getenv("ENABLE_MULTI_LANGUAGE_TESTS"))) {
run.onExclude("multiLanguage");
}
groups.setRun(run);
suite.setGroups(groups);
}
@@ -8,7 +8,8 @@ import io.ray.runtime.object.ObjectSerializer;
import org.testng.Assert;
import org.testng.annotations.Test;
public class RaySerializerTest extends BaseMultiLanguageTest {
@Test(groups = {"cluster"})
public class RaySerializerTest extends BaseTest {
@Test
public void testSerializePyActor() {
@@ -25,7 +25,9 @@ public class RayletConfigTest extends BaseTest {
public static class TestActor {
public String getConfigValue() {
return TestUtils.getRuntime().getRayConfig().rayletConfigParameters.get(RAY_CONFIG_KEY);
return TestUtils.getRuntime().getRayConfig()
.rayletConfigParameters.get(RAY_CONFIG_KEY)
.toString();
}
}
@@ -11,13 +11,11 @@ public class RedisPasswordTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.redis.head-password", "12345678");
System.setProperty("ray.redis.password", "12345678");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.redis.head-password");
System.clearProperty("ray.redis.password");
}
@@ -27,12 +27,12 @@ import org.testng.annotations.Test;
public class ReferenceCountingTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.object-store.size", "100 MB");
System.setProperty("ray.head-args.0", "--object-store-memory=" + 100L * 1024 * 1024);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.object-store.size");
System.clearProperty("ray.head-args.0");
}
/**
@@ -20,12 +20,14 @@ public class ResourcesManagementTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.resources", "CPU:4,RES-A:4");
System.setProperty("ray.head-args.0", "--num-cpus=4");
System.setProperty("ray.head-args.1", "--resources={\"RES-A\":4}");
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.resources");
System.clearProperty("ray.head-args.0");
System.clearProperty("ray.head-args.1");
}
public static Integer echo(Integer number) {
@@ -14,8 +14,6 @@ import org.testng.annotations.Test;
public class RuntimeContextTest extends BaseTest {
private static JobId JOB_ID = getJobId();
private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket";
private static JobId getJobId() {
// Must be stable across different processes.
@@ -27,23 +25,16 @@ public class RuntimeContextTest extends BaseTest {
@BeforeClass
public void setUp() {
System.setProperty("ray.job.id", JOB_ID.toString());
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME);
}
@AfterClass
public void tearDown() {
System.clearProperty("ray.job.id");
System.clearProperty("ray.raylet.socket-name");
System.clearProperty("ray.object-store.socket-name");
}
@Test
public void testRuntimeContextInDriver() {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
}
public static class RuntimeContextTester {
@@ -51,9 +42,6 @@ public class RuntimeContextTest extends BaseTest {
public String testRuntimeContext(ActorId actorId) {
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
return "ok";
}
}