diff --git a/java/pom.xml b/java/pom.xml
index d1c7aaee3..a7b9158b7 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -111,9 +111,9 @@
1.10.19
- junit
- junit
- 4.11
+ org.testng
+ testng
+ 6.9.9
diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml
index bc40ac776..b5578404d 100644
--- a/java/runtime/pom.xml
+++ b/java/runtime/pom.xml
@@ -72,8 +72,8 @@
- junit
- junit
+ org.testng
+ testng
test
diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
index f5ff1e481..e0307635a 100644
--- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
+++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java
@@ -7,15 +7,14 @@ import java.nio.file.StandardCopyOption;
import java.util.Map;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Test;
import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc0;
import org.ray.api.function.RayFunc1;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.FunctionManager.DriverFunctionTable;
+import org.testng.Assert;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
/**
* Tests for {@link FunctionManager}
diff --git a/java/test/pom.xml b/java/test/pom.xml
index d2031537f..448364641 100644
--- a/java/test/pom.xml
+++ b/java/test/pom.xml
@@ -30,8 +30,8 @@
- junit
- junit
+ org.testng
+ testng
diff --git a/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java b/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java
index 9945e1777..5e0c28c3c 100644
--- a/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java
+++ b/java/test/src/main/java/org/ray/api/benchmark/ActorPressTest.java
@@ -1,10 +1,10 @@
package org.ray.api.benchmark;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
+import org.testng.annotations.Test;
public class ActorPressTest extends RayBenchmarkTest {
diff --git a/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java b/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java
index bf5d9c5ac..3db8737bc 100644
--- a/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java
+++ b/java/test/src/main/java/org/ray/api/benchmark/MaxPressureTest.java
@@ -1,10 +1,10 @@
package org.ray.api.benchmark;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
+import org.testng.annotations.Test;
public class MaxPressureTest extends RayBenchmarkTest {
diff --git a/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java b/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java
index 4c44a0332..d379f6575 100644
--- a/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java
+++ b/java/test/src/main/java/org/ray/api/benchmark/RateLimiterPressureTest.java
@@ -1,10 +1,10 @@
package org.ray.api.benchmark;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
+import org.testng.annotations.Test;
public class RateLimiterPressureTest extends RayBenchmarkTest {
diff --git a/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java b/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java
index 8e96b78fa..4ea0b5fa8 100644
--- a/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java
+++ b/java/test/src/main/java/org/ray/api/benchmark/RayBenchmarkTest.java
@@ -6,7 +6,6 @@ import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
-import org.junit.Assert;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
@@ -14,6 +13,7 @@ import org.ray.api.annotation.RayRemote;
import org.ray.api.test.BaseTest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.testng.Assert;
public abstract class RayBenchmarkTest extends BaseTest implements Serializable {
diff --git a/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java b/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java
index 5e82163bc..3c4250614 100644
--- a/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java
+++ b/java/test/src/main/java/org/ray/api/benchmark/SingleLatencyTest.java
@@ -1,10 +1,10 @@
package org.ray.api.benchmark;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
+import org.testng.annotations.Test;
public class SingleLatencyTest extends RayBenchmarkTest {
diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java
index 59bba919f..a0fdbfbf0 100644
--- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java
+++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java
@@ -5,12 +5,12 @@ import static org.ray.runtime.util.SystemUtil.pid;
import java.io.IOException;
import java.util.HashMap;
import java.util.concurrent.TimeUnit;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.annotation.RayRemote;
import org.ray.api.options.ActorCreationOptions;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class ActorReconstructionTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java
index b7b839a79..2c47d84c3 100644
--- a/java/test/src/main/java/org/ray/api/test/ActorTest.java
+++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java
@@ -1,7 +1,5 @@
package org.ray.api.test;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
@@ -9,6 +7,8 @@ import org.ray.api.annotation.RayRemote;
import org.ray.api.function.RayFunc2;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayActorImpl;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class ActorTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/BaseTest.java b/java/test/src/main/java/org/ray/api/test/BaseTest.java
index f7bcf01b4..23f893a46 100644
--- a/java/test/src/main/java/org/ray/api/test/BaseTest.java
+++ b/java/test/src/main/java/org/ray/api/test/BaseTest.java
@@ -1,20 +1,22 @@
package org.ray.api.test;
import java.io.File;
-import org.junit.After;
-import org.junit.Before;
import org.ray.api.Ray;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.BeforeTest;
public class BaseTest {
- @Before
+ @BeforeMethod
public void setUp() {
System.setProperty("ray.home", "../..");
System.setProperty("ray.resources", "CPU:4,RES-A:4");
Ray.init();
}
- @After
+ @AfterMethod
public void tearDown() {
// TODO(qwang): This is double check to check that the socket file is removed actually.
// We could not enable this until `systemInfo` enabled.
diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java
index 0e9d6da91..3c1fa94d8 100644
--- a/java/test/src/main/java/org/ray/api/test/FailureTest.java
+++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java
@@ -1,11 +1,11 @@
package org.ray.api.test;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.exception.RayException;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class FailureTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java
index 9f31363e8..feb07fe2c 100644
--- a/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java
+++ b/java/test/src/main/java/org/ray/api/test/HelloWorldTest.java
@@ -1,10 +1,10 @@
package org.ray.api.test;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* Hello world.
diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
index c95e7093c..6bbd39ffa 100644
--- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
+++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
@@ -9,13 +9,13 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class MultiThreadingTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java
index d38c46992..eaa99a289 100644
--- a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java
+++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java
@@ -3,12 +3,11 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* Test putting and getting objects.
diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java
index 795e0efdb..736150b8d 100644
--- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java
+++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java
@@ -3,14 +3,13 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class PlasmaFreeTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java
index 08e90e589..c97e3fe91 100644
--- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java
+++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java
@@ -2,15 +2,13 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-
import java.io.Serializable;
import java.util.List;
import java.util.Map;
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
import org.ray.api.Ray;
import org.ray.api.annotation.RayRemote;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* Test Ray.call API
@@ -87,7 +85,7 @@ public class RayCallTest extends BaseTest {
Assert.assertEquals(1, (long) Ray.call(RayCallTest::testLong, 1L).get());
Assert.assertEquals(1.0, Ray.call(RayCallTest::testDouble, 1.0).get(), 0.0);
Assert.assertEquals(1.0f, Ray.call(RayCallTest::testFloat, 1.0f).get(), 0.0);
- Assert.assertEquals(true, Ray.call(RayCallTest::testBool, true).get());
+ Assert.assertTrue(Ray.call(RayCallTest::testBool, true).get());
Assert.assertEquals("foo", Ray.call(RayCallTest::testString, "foo").get());
List list = ImmutableList.of(1, 2, 3);
Assert.assertEquals(list, Ray.call(RayCallTest::testList, list).get());
diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
index 8260b39d4..b7b655bd3 100644
--- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
+++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java
@@ -1,10 +1,10 @@
package org.ray.api.test;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class RayConfigTest {
diff --git a/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java b/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java
index 8fb430bdc..784617fc1 100644
--- a/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java
+++ b/java/test/src/main/java/org/ray/api/test/RayMethodsTest.java
@@ -3,11 +3,11 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* Integration test for Ray.*
diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
index 36abda2f3..15bd84d1e 100644
--- a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
+++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java
@@ -2,8 +2,6 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
@@ -11,6 +9,8 @@ import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* Resources Management Test.
diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java
index a85a8ca2b..652bbaf4e 100644
--- a/java/test/src/main/java/org/ray/api/test/StressTest.java
+++ b/java/test/src/main/java/org/ray/api/test/StressTest.java
@@ -3,12 +3,12 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class StressTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java
index 61443ed18..5607e81cd 100644
--- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java
+++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java
@@ -2,12 +2,11 @@ package org.ray.api.test;
import java.nio.ByteBuffer;
import java.util.Arrays;
-import java.util.List;
import javax.xml.bind.DatatypeConverter;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.UniqueIdUtil;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class UniqueIdTest {
diff --git a/java/test/src/main/java/org/ray/api/test/WaitTest.java b/java/test/src/main/java/org/ray/api/test/WaitTest.java
index 49b5f8365..e82b99d36 100644
--- a/java/test/src/main/java/org/ray/api/test/WaitTest.java
+++ b/java/test/src/main/java/org/ray/api/test/WaitTest.java
@@ -3,12 +3,12 @@ package org.ray.api.test;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
-import org.junit.Assert;
-import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;
+import org.testng.Assert;
+import org.testng.annotations.Test;
public class WaitTest extends BaseTest {
diff --git a/java/test/src/main/java/org/ray/api/test/WordCountTest.java b/java/test/src/main/java/org/ray/api/test/WordCountTest.java
index 2c019b8b2..43a1ae7cd 100644
--- a/java/test/src/main/java/org/ray/api/test/WordCountTest.java
+++ b/java/test/src/main/java/org/ray/api/test/WordCountTest.java
@@ -3,11 +3,12 @@ package org.ray.api.test;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.List;
-import org.junit.Assert;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.ray.runtime.util.FileUtil;
+import org.testng.Assert;
+import org.testng.annotations.Test;
/**
* given a directory of document files on each "machine", we would like to count the appearance of