diff --git a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java
index e2197cfd1..a91781ac9 100644
--- a/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java
+++ b/java/runtime/src/main/java/io/ray/runtime/config/RayConfig.java
@@ -31,8 +31,6 @@ public class RayConfig {
   public static final String DEFAULT_CONFIG_FILE = "ray.default.conf";
   public static final String CUSTOM_CONFIG_FILE = "ray.conf";
 
-  private static int DEFAULT_NUM_JAVA_WORKER_PER_PROCESS = 10;
-
   private static final Random RANDOM = new Random();
 
   private static final DateTimeFormatter DATE_TIME_FORMATTER =
@@ -97,7 +95,6 @@ public class RayConfig {
     }
   }
 
-
   public final int numWorkersPerProcess;
 
   public final List jvmOptionsForJavaWorker;
@@ -242,14 +239,13 @@ public class RayConfig {
     }
 
     if (!enableMultiTenancy) {
-      numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
-    } else {
-      final int localNumWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
-      if (localNumWorkersPerProcess <= 0) {
-        numWorkersPerProcess = DEFAULT_NUM_JAVA_WORKER_PER_PROCESS;
+      if (!isDriver) {
+        numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
       } else {
-        numWorkersPerProcess = localNumWorkersPerProcess;
+        numWorkersPerProcess = 1; // Actually this value isn't used in RayNativeRuntime.
       }
+    } else {
+      numWorkersPerProcess = config.getInt("ray.job.num-java-workers-per-process");
     }
 
     // Validate config.
diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf
index b3fa59520..66002f48d 100644
--- a/java/runtime/src/main/resources/ray.default.conf
+++ b/java/runtime/src/main/resources/ray.default.conf
@@ -27,7 +27,7 @@ ray {
     // the path for job 123 will be '/tmp/job_resources/123'.
     resource-path: ""
     /// The number of java worker per worker process.
-    num-java-workers-per-process: 10
+    num-java-workers-per-process: 1
     /// The jvm options for java workers of the job.
     jvm-options: []
     // Environment variables to be set on worker processes.
@@ -98,7 +98,6 @@ ray {
   raylet {
     // See src/ray/ray_config_def.h for options.
     config {
-      num_workers_per_process_java: 10
       // TODO(zhuohan): enable this for java
       put_small_object_in_memory_store: false
     }
diff --git a/java/test.sh b/java/test.sh
index 875522f5c..32105a158 100755
--- a/java/test.sh
+++ b/java/test.sh
@@ -33,15 +33,18 @@ bazel build //java:gen_maven_deps
 
 echo "Build test jar."
 bazel build //java:all_tests_deploy.jar
 
+# Enable multi-worker feature in Java test
+TEST_ARGS=(-Dray.raylet.config.num_workers_per_process_java=10 -Dray.job.num-java-workers-per-process=10)
+
 echo "Running tests under cluster mode."
 # TODO(hchen): Ideally, we should use the following bazel command to run Java tests. However, if there're skipped tests,
 # TestNG will exit with code 2. And bazel treats it as test failure.
 # bazel test //java:all_tests --action_env=ENABLE_MULTI_LANGUAGE_TESTS=1 --config=ci || cluster_exit_code=$?
-ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
+ENABLE_MULTI_LANGUAGE_TESTS=1 run_testng java -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
 
 echo "Running tests under single-process mode."
 # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --config=ci || single_exit_code=$?
-run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
+run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp "$ROOT_DIR"/../bazel-bin/java/all_tests_deploy.jar "${TEST_ARGS[@]}" org.testng.TestNG -d /tmp/ray_java_test_output "$ROOT_DIR"/testng.xml
 
 popd
diff --git a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java b/java/test/src/main/java/io/ray/test/ClassLoaderTest.java
index e5f43e2d0..b6b769f62 100644
--- a/java/test/src/main/java/io/ray/test/ClassLoaderTest.java
+++ b/java/test/src/main/java/io/ray/test/ClassLoaderTest.java
@@ -40,7 +40,7 @@ public class ClassLoaderTest extends BaseTest {
 
   @Test(groups = {"cluster"})
   public void testClassLoaderInMultiThreading() throws Exception {
-    Assert.assertTrue(TestUtils.getRuntime().getRayConfig().numWorkersPerProcess > 1);
+    Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1);
 
     final String jobResourcePath = resourcePath + "/" + Ray.getRuntimeContext().getCurrentJobId();
     File jobResourceDir = new File(jobResourcePath);
diff --git a/java/test/src/main/java/io/ray/test/ExitActorTest.java b/java/test/src/main/java/io/ray/test/ExitActorTest.java
index 4ca537bc8..9c95cb960 100644
--- a/java/test/src/main/java/io/ray/test/ExitActorTest.java
+++ b/java/test/src/main/java/io/ray/test/ExitActorTest.java
@@ -73,7 +73,7 @@ public class ExitActorTest extends BaseTest {
   }
 
   public void testExitActorInMultiWorker() {
-    Assert.assertTrue(TestUtils.getRuntime().getRayConfig().numWorkersPerProcess > 1);
+    Assert.assertTrue(TestUtils.getNumWorkersPerProcess() > 1);
     ActorHandle actor1 = Ray.actor(ExitingActor::new)
         .setMaxRestarts(10000).remote();
     int pid = actor1.task(ExitingActor::getPid).remote().get();
diff --git a/java/test/src/main/java/io/ray/test/FailureTest.java b/java/test/src/main/java/io/ray/test/FailureTest.java
index 07eca0263..3f779bf23 100644
--- a/java/test/src/main/java/io/ray/test/FailureTest.java
+++ b/java/test/src/main/java/io/ray/test/FailureTest.java
@@ -23,17 +23,25 @@ public class FailureTest extends BaseTest {
 
   private static final String EXCEPTION_MESSAGE = "Oops";
 
+  private String oldNumWorkersPerProcess;
+
   @BeforeClass
   public void setUp() {
     // This is needed by `testGetThrowsQuicklyWhenFoundException`.
     // Set one worker per process. Otherwise, if `badFunc2` and `slowFunc` run in the same
     // process, `sleep` will delay `System.exit`.
+    oldNumWorkersPerProcess = System.getProperty("ray.raylet.config.num_workers_per_process_java");
     System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
   }
 
   @AfterClass
   public void tearDown() {
-    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    // System.setProperty rejects null values, so clear the key if it was unset beforehand.
+    if (oldNumWorkersPerProcess == null) {
+      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    } else {
+      System.setProperty("ray.raylet.config.num_workers_per_process_java", oldNumWorkersPerProcess);
+    }
   }
 
   public static int badFunc() {
diff --git a/java/test/src/main/java/io/ray/test/JobConfigTest.java b/java/test/src/main/java/io/ray/test/JobConfigTest.java
index 5bb4ea48a..f7a7b79c3 100644
--- a/java/test/src/main/java/io/ray/test/JobConfigTest.java
+++ b/java/test/src/main/java/io/ray/test/JobConfigTest.java
@@ -1,7 +1,6 @@
 package io.ray.test;
 
 import io.ray.api.ActorHandle;
-import io.ray.api.ObjectRef;
 import io.ray.api.Ray;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
@@ -11,9 +10,12 @@ import org.testng.annotations.Test;
 @Test(groups = {"cluster"})
 public class JobConfigTest extends BaseTest {
 
+  private String oldNumWorkersPerProcess;
+
   @BeforeClass
   public void setupJobConfig() {
     System.setProperty("ray.raylet.config.enable_multi_tenancy", "true");
+    oldNumWorkersPerProcess = System.getProperty("ray.job.num-java-workers-per-process");
     System.setProperty("ray.job.num-java-workers-per-process", "3");
     System.setProperty("ray.job.jvm-options.0", "-DX=999");
     System.setProperty("ray.job.jvm-options.1", "-DY=998");
@@ -24,7 +26,12 @@ public class JobConfigTest extends BaseTest {
   @AfterClass
   public void tearDownJobConfig() {
     System.clearProperty("ray.raylet.config.enable_multi_tenancy");
-    System.clearProperty("ray.job.num-java-workers-per-process");
+    // System.setProperty rejects null values, so clear the key if it was unset beforehand.
+    if (oldNumWorkersPerProcess == null) {
+      System.clearProperty("ray.job.num-java-workers-per-process");
+    } else {
+      System.setProperty("ray.job.num-java-workers-per-process", oldNumWorkersPerProcess);
+    }
     System.clearProperty("ray.job.jvm-options.0");
     System.clearProperty("ray.job.jvm-options.1");
     System.clearProperty("ray.job.worker-env.foo1");
@@ -39,16 +46,8 @@ public class JobConfigTest extends BaseTest {
     return System.getenv(key);
   }
 
-  public static Integer getWorkersNum() {
-    return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
-  }
-
   public static class MyActor {
 
-    public Integer getWorkersNum() {
-      return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
-    }
-
     public String getJvmOptions(String propertyName) {
       return System.getProperty(propertyName);
     }
@@ -68,9 +67,8 @@ public class JobConfigTest extends BaseTest {
     Assert.assertEquals("bar2", Ray.task(JobConfigTest::getEnvVariable, "foo2").remote().get());
   }
 
-  public void testNumJavaWorkerPerProcess() {
-    ObjectRef obj = Ray.task(JobConfigTest::getWorkersNum).remote();
-    Assert.assertEquals(3, (int) obj.get());
+  public void testNumJavaWorkersPerProcess() {
+    Assert.assertEquals(TestUtils.getNumWorkersPerProcess(), 3);
   }
 
 
@@ -84,9 +82,5 @@ public class JobConfigTest extends BaseTest {
     // test worker env variables
     Assert.assertEquals("bar1", Ray.task(MyActor::getEnvVariable, "foo1").remote().get());
     Assert.assertEquals("bar2", Ray.task(MyActor::getEnvVariable, "foo2").remote().get());
-
-    // test workers number.
-    ObjectRef obj2 = actor.task(MyActor::getWorkersNum).remote();
-    Assert.assertEquals(3, (int) obj2.get());
   }
 }
diff --git a/java/test/src/main/java/io/ray/test/KillActorTest.java b/java/test/src/main/java/io/ray/test/KillActorTest.java
index c8ad3b8ed..cf88896d8 100644
--- a/java/test/src/main/java/io/ray/test/KillActorTest.java
+++ b/java/test/src/main/java/io/ray/test/KillActorTest.java
@@ -14,14 +14,22 @@ import org.testng.annotations.Test;
 @Test(groups = {"cluster"})
 public class KillActorTest extends BaseTest {
 
+  private String oldNumWorkersPerProcess;
+
   @BeforeClass
   public void setUp() {
+    oldNumWorkersPerProcess = System.getProperty("ray.raylet.config.num_workers_per_process_java");
     System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
   }
 
   @AfterClass
   public void tearDown() {
-    System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    // System.setProperty rejects null values, so clear the key if it was unset beforehand.
+    if (oldNumWorkersPerProcess == null) {
+      System.clearProperty("ray.raylet.config.num_workers_per_process_java");
+    } else {
+      System.setProperty("ray.raylet.config.num_workers_per_process_java", oldNumWorkersPerProcess);
+    }
   }
 
   public static class HangActor {
diff --git a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java
index bbd360637..f9371748d 100644
--- a/java/test/src/main/java/io/ray/test/MultiThreadingTest.java
+++ b/java/test/src/main/java/io/ray/test/MultiThreadingTest.java
@@ -19,7 +19,7 @@ import org.slf4j.LoggerFactory;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
-@Test
+@Test(groups = {"cluster"})
 public class MultiThreadingTest extends BaseTest {
 
   private static final Logger LOGGER = LoggerFactory.getLogger(MultiThreadingTest.class);
@@ -221,11 +221,6 @@ public class MultiThreadingTest extends BaseTest {
     return true;
   }
 
-  @Test
-  public void testMissingWrapRunnableInDriver() throws InterruptedException {
-    testMissingWrapRunnable();
-  }
-
   @Test
   public void testMissingWrapRunnableInWorker() {
     Ray.task(MultiThreadingTest::testMissingWrapRunnable).remote().get();
diff --git a/java/test/src/main/java/io/ray/test/TestUtils.java b/java/test/src/main/java/io/ray/test/TestUtils.java
index 88f4d25cd..8075a40c2 100644
--- a/java/test/src/main/java/io/ray/test/TestUtils.java
+++ b/java/test/src/main/java/io/ray/test/TestUtils.java
@@ -3,6 +3,7 @@ package io.ray.test;
 import com.google.common.base.Preconditions;
 import io.ray.api.ObjectRef;
 import io.ray.api.Ray;
+import io.ray.runtime.AbstractRayRuntime;
 import io.ray.runtime.RayRuntimeInternal;
 import io.ray.runtime.RayRuntimeProxy;
 import io.ray.runtime.config.RunMode;
@@ -83,8 +84,24 @@ public class TestUtils {
   }
 
   public static RayRuntimeInternal getUnderlyingRuntime() {
+    if (Ray.internal() instanceof AbstractRayRuntime) {
+      return (RayRuntimeInternal) Ray.internal();
+    }
     RayRuntimeProxy proxy = (RayRuntimeProxy) (java.lang.reflect.Proxy
         .getInvocationHandler(Ray.internal()));
     return proxy.getRuntimeObject();
   }
+
+  /** Remote function body: reads {@code numWorkersPerProcess} from the executing runtime. */
+  private static int getNumWorkersPerProcessRemoteFunction() {
+    return TestUtils.getRuntime().getRayConfig().numWorkersPerProcess;
+  }
+
+  /**
+   * Returns the {@code numWorkersPerProcess} value seen by whichever runtime executes a
+   * submitted Ray task, blocking until that task's result is available.
+   */
+  public static int getNumWorkersPerProcess() {
+    return Ray.task(TestUtils::getNumWorkersPerProcessRemoteFunction).remote().get();
+  }
 }
diff --git a/python/ray/job_config.py b/python/ray/job_config.py
index b82160d99..eedee46ef 100644
--- a/python/ray/job_config.py
+++ b/python/ray/job_config.py
@@ -15,7 +15,7 @@ class JobConfig:
     def __init__(
             self,
             worker_env=None,
-            num_java_workers_per_process=10,
+            num_java_workers_per_process=1,
             jvm_options=None,
     ):
         if worker_env is None:
diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h
index 9b99dc9b7..73e9e0ded 100644
--- a/src/ray/common/ray_config_def.h
+++ b/src/ray/common/ray_config_def.h
@@ -197,7 +197,7 @@ RAY_CONFIG(uint64_t, object_manager_default_chunk_size, 1000000)
 RAY_CONFIG(int, num_workers_per_process_python, 1)
 
 /// Number of workers per Java worker process
-RAY_CONFIG(int, num_workers_per_process_java, 10)
+RAY_CONFIG(int, num_workers_per_process_java, 1)
 
 /// Number of workers per CPP worker process
 RAY_CONFIG(int, num_workers_per_process_cpp, 1)