[Java] Support multiple workers in Java worker process (#5505)

This commit is contained in:
Kai Yang
2019-09-07 22:52:05 +08:00
committed by Hao Chen
parent d89ceb3ee5
commit 732336fc4f
37 changed files with 512 additions and 148 deletions
@@ -1,8 +1,11 @@
package org.ray.api;
import com.google.common.base.Preconditions;
import java.util.function.Supplier;
import org.ray.api.annotation.RayRemote;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.config.RunMode;
import org.testng.Assert;
import org.testng.SkipException;
@@ -12,8 +15,7 @@ public class TestUtils {
private static final int WAIT_INTERVAL_MS = 5;
public static void skipTestUnderSingleProcess() {
AbstractRayRuntime runtime = (AbstractRayRuntime)Ray.internal();
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
if (getRuntime().getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
throw new SkipException("This test doesn't work under single-process mode.");
}
}
@@ -62,4 +64,13 @@ public class TestUtils {
RayObject<String> obj = Ray.call(TestUtils::hi);
Assert.assertEquals(obj.get(), "hi");
}
public static AbstractRayRuntime getRuntime() {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
Preconditions.checkState(runtime instanceof AbstractRayRuntime);
return (AbstractRayRuntime) runtime;
}
}
@@ -11,7 +11,6 @@ import org.ray.api.TestUtils;
import org.ray.api.annotation.RayRemote;
import org.ray.api.exception.UnreconstructableException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.actor.NativeRayActor;
import org.ray.runtime.object.NativeRayObject;
import org.testng.Assert;
@@ -117,7 +116,7 @@ public class ActorTest extends BaseTest {
Ray.internal().free(ImmutableList.of(value.getId()), false, false);
// Wait until the object is deleted, because the above free operation is async.
while (true) {
NativeRayObject result = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
NativeRayObject result = TestUtils.getRuntime().getObjectStore()
.getRaw(ImmutableList.of(value.getId()), 0).get(0);
if (result == null) {
break;
@@ -1,5 +1,6 @@
package org.ray.api.test;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
@@ -63,7 +64,7 @@ public abstract class BaseMultiLanguageTest {
// Start ray cluster.
String workerOptions =
" -classpath " + System.getProperty("java.class.path");
final List<String> startCommand = ImmutableList.of(
List<String> startCommand = ImmutableList.of(
"ray",
"start",
"--head",
@@ -74,6 +75,13 @@ public abstract class BaseMultiLanguageTest {
"--include-java",
"--java-worker-options=" + workerOptions
);
String numWorkersPerProcessJava = System
.getProperty("ray.raylet.config.num_workers_per_process_java");
if (!Strings.isNullOrEmpty(numWorkersPerProcessJava)) {
startCommand = ImmutableList.<String>builder().addAll(startCommand)
.add(String.format("--internal-config={\"num_workers_per_process_java\": %s}",
numWorkersPerProcessJava)).build();
}
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
throw new RuntimeException("Couldn't start ray cluster.");
}
@@ -127,6 +127,7 @@ public class FailureTest extends BaseTest {
TestUtils.skipTestUnderSingleProcess();
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
FailureTest::badFunc2);
TestUtils.warmUpCluster();
for (RayFunc0<Integer> badFunc : badFunctions) {
RayObject<Integer> obj1 = Ray.call(badFunc);
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
@@ -2,11 +2,9 @@ package org.ray.api.test;
import com.google.common.base.Preconditions;
import java.util.List;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.id.JobId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.gcs.GcsClient;
import org.testng.Assert;
@@ -29,10 +27,10 @@ public class GcsClientTest extends BaseTest {
@Test
public void testGetAllNodeInfo() {
TestUtils.skipTestUnderSingleProcess();
RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig();
RayConfig config = TestUtils.getRuntime().getRayConfig();
Preconditions.checkNotNull(config);
GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient();
GcsClient gcsClient = TestUtils.getRuntime().getGcsClient();
List<NodeInfo> allNodeInfo = gcsClient.getAllNodeInfo();
Assert.assertEquals(allNodeInfo.size(), 1);
Assert.assertEquals(allNodeInfo.get(0).nodeAddress, config.nodeIp);
@@ -43,11 +41,11 @@ public class GcsClientTest extends BaseTest {
@Test
public void testNextJob() {
TestUtils.skipTestUnderSingleProcess();
RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig();
RayConfig config = TestUtils.getRuntime().getRayConfig();
// The value of job id of this driver in cluster should be `1L`.
Assert.assertEquals(config.getJobId(), JobId.fromInt(1));
GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient();
GcsClient gcsClient = TestUtils.getRuntime().getGcsClient();
for (int i = 2; i < 100; ++i) {
Assert.assertEquals(gcsClient.nextJobId(), JobId.fromInt(i));
}
@@ -54,19 +54,22 @@ public class MultiThreadingTest extends BaseTest {
}
@RayRemote
public ActorId getCurrentActorId() {
final ActorId[] result = new ActorId[1];
Thread thread = new Thread(() -> {
result[0] = Ray.getRuntimeContext().getCurrentActorId();
});
public ActorId getCurrentActorId() throws Exception {
final Object[] result = new Object[1];
Thread thread = new Thread(Ray.wrapRunnable(() -> {
try {
result[0] = Ray.getRuntimeContext().getCurrentActorId();
} catch (Exception e) {
result[0] = e;
}
}));
thread.start();
try {
thread.join();
} catch (InterruptedException e) {
throw new RuntimeException(e);
thread.join();
if (result[0] instanceof Exception) {
throw (Exception) result[0];
}
Assert.assertEquals(result[0], actorId);
return result[0];
return (ActorId) result[0];
}
}
@@ -147,13 +150,13 @@ public class MultiThreadingTest extends BaseTest {
try {
List<Future<String>> futures = new ArrayList<>();
for (int i = 0; i < NUM_THREADS; i++) {
Callable<String> task = () -> {
Callable<String> task = Ray.wrapCallable(() -> {
for (int j = 0; j < numRepeats; j++) {
TimeUnit.MILLISECONDS.sleep(1);
testCase.run();
}
return "ok";
};
});
futures.add(service.submit(task));
}
for (Future<String> future : futures) {
@@ -7,7 +7,6 @@ import org.ray.api.RayObject;
import org.ray.api.TestUtils;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.TaskId;
import org.ray.runtime.AbstractRayRuntime;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -26,7 +25,7 @@ public class PlasmaFreeTest extends BaseTest {
Ray.internal().free(ImmutableList.of(helloId.getId()), true, false);
final boolean result = TestUtils.waitForCondition(() ->
((AbstractRayRuntime) Ray.internal()).getObjectStore()
TestUtils.getRuntime().getObjectStore()
.getRaw(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50);
Assert.assertTrue(result);
}
@@ -40,7 +39,7 @@ public class PlasmaFreeTest extends BaseTest {
TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH));
final boolean result = TestUtils.waitForCondition(
() -> !(((AbstractRayRuntime) Ray.internal()).getGcsClient())
() -> !TestUtils.getRuntime().getGcsClient()
.rayletTaskExistsInGcs(taskId), 50);
Assert.assertTrue(result);
}
@@ -1,11 +1,8 @@
package org.ray.api.test;
import java.util.Collections;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -16,8 +13,7 @@ public class PlasmaStoreTest extends BaseTest {
public void testPutWithDuplicateId() {
TestUtils.skipTestUnderSingleProcess();
ObjectId objectId = ObjectId.fromRandom();
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
ObjectStore objectStore = runtime.getObjectStore();
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
objectStore.put("1", objectId);
Assert.assertEquals(Ray.get(objectId), "1");
objectStore.put("2", objectId);
@@ -6,9 +6,9 @@ import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.ray.api.Ray;
import org.ray.api.TestUtils;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -80,7 +80,7 @@ public class RayCallTest extends BaseTest {
@RayRemote
private static void testNoReturn(ObjectId objectId) {
// Put an object in object store to inform driver that this function is executing.
((AbstractRayRuntime) Ray.internal()).getObjectStore().put(1, objectId);
TestUtils.getRuntime().getObjectStore().put(1, objectId);
}
/**
@@ -2,8 +2,8 @@ package org.ray.api.test;
import org.ray.api.Ray;
import org.ray.api.RayPyActor;
import org.ray.api.TestUtils;
import org.ray.api.id.ObjectId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.testng.Assert;
@@ -14,7 +14,7 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
@Test
public void testSerializePyActor() {
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
ObjectStore objectStore = ((AbstractRayRuntime) Ray.internal()).getObjectStore();
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
RayPyActor result = (RayPyActor) objectStore
.deserialize(nativeRayObject, ObjectId.fromRandom());