mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Java] Support multiple workers in Java worker process (#5505)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package org.ray.api;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
@@ -117,6 +118,28 @@ public final class Ray extends RayCall {
|
||||
return runtime.wait(waitList, waitList.size(), Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in there own threads, they should wrap their {@link Runnable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param runnable The runnable to wrap.
|
||||
* @return The wrapped runnable.
|
||||
*/
|
||||
public static Runnable wrapRunnable(Runnable runnable) {
|
||||
return runtime.wrapRunnable(runnable);
|
||||
}
|
||||
|
||||
/**
|
||||
* If users want to use Ray API in there own threads, they should wrap their {@link Callable}
|
||||
* objects with this method.
|
||||
*
|
||||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
public static Callable wrapCallable(Callable callable) {
|
||||
return runtime.wrapCallable(callable);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the underlying runtime instance.
|
||||
*/
|
||||
|
||||
@@ -30,7 +30,7 @@ public class ActorCreationOptions extends BaseTaskOptions {
|
||||
|
||||
private Map<String, Double> resources = new HashMap<>();
|
||||
private int maxReconstructions = NO_RECONSTRUCTION;
|
||||
private String jvmOptions = "";
|
||||
private String jvmOptions = null;
|
||||
|
||||
public Builder setResources(Map<String, Double> resources) {
|
||||
this.resources = resources;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.api.runtime;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
@@ -141,4 +142,18 @@ public interface RayRuntime {
|
||||
*/
|
||||
RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options);
|
||||
|
||||
/**
|
||||
* Wrap a {@link Runnable} with necessary context capture.
|
||||
* @param runnable The runnable to wrap.
|
||||
* @return The wrapped runnable.
|
||||
*/
|
||||
Runnable wrapRunnable(Runnable runnable);
|
||||
|
||||
/**
|
||||
* Wrap a {@link Callable} with necessary context capture.
|
||||
* @param callable The callable to wrap.
|
||||
* @return The wrapped callable.
|
||||
*/
|
||||
Callable wrapCallable(Callable callable);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
@@ -59,11 +60,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
public abstract void start() throws Exception;
|
||||
|
||||
@Override
|
||||
public abstract void shutdown();
|
||||
|
||||
@@ -168,6 +164,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
return runnable;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
return callable;
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, int numReturns, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
|
||||
@@ -4,11 +4,12 @@ import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtime.RayRuntimeFactory;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The default Ray runtime factory. It produces an instance of AbstractRayRuntime.
|
||||
* The default Ray runtime factory. It produces an instance of RayRuntime.
|
||||
*/
|
||||
public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
|
||||
|
||||
@@ -18,14 +19,16 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
|
||||
public RayRuntime createRayRuntime() {
|
||||
RayConfig rayConfig = RayConfig.create();
|
||||
try {
|
||||
AbstractRayRuntime runtime;
|
||||
RayRuntime runtime;
|
||||
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
|
||||
runtime = new RayDevRuntime(rayConfig);
|
||||
} else {
|
||||
runtime = new RayNativeRuntime(rayConfig);
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
runtime = new RayNativeRuntime(rayConfig);
|
||||
} else {
|
||||
runtime = new RayMultiWorkerNativeRuntime(rayConfig);
|
||||
}
|
||||
}
|
||||
|
||||
runtime.start();
|
||||
return runtime;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to initialize ray runtime", e);
|
||||
|
||||
@@ -11,14 +11,10 @@ import org.ray.runtime.task.TaskExecutor;
|
||||
|
||||
public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
|
||||
* instances and redirect calls to the correct one based on thread context.
|
||||
*/
|
||||
public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);
|
||||
|
||||
/**
|
||||
* The number of workers per worker process.
|
||||
*/
|
||||
private final int numWorkers;
|
||||
/**
|
||||
* The worker threads.
|
||||
*/
|
||||
private final Thread[] threads;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instances of workers.
|
||||
*/
|
||||
private final RayNativeRuntime[] runtimes;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instance of current thread.
|
||||
*/
|
||||
private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();
|
||||
|
||||
public RayMultiWorkerNativeRuntime(RayConfig rayConfig) {
|
||||
Preconditions.checkState(
|
||||
rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
|
||||
Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
|
||||
"numWorkersPerProcess must be greater than 0.");
|
||||
numWorkers = rayConfig.numWorkersPerProcess;
|
||||
runtimes = new RayNativeRuntime[numWorkers];
|
||||
threads = new Thread[numWorkers];
|
||||
|
||||
LOGGER.info("Starting {} workers.", numWorkers);
|
||||
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
final int workerIndex = i;
|
||||
threads[i] = new Thread(() -> {
|
||||
RayNativeRuntime runtime = new RayNativeRuntime(rayConfig);
|
||||
runtimes[workerIndex] = runtime;
|
||||
currentThreadRuntime.set(runtime);
|
||||
runtime.run();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void run() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
threads[i].start();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
runtimes[i].shutdown();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public RayNativeRuntime getCurrentRuntime() {
|
||||
RayNativeRuntime currentRuntime = currentThreadRuntime.get();
|
||||
Preconditions.checkNotNull(currentRuntime,
|
||||
"RayRuntime is not set on current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your `Runnable`s or `Callable`s with"
|
||||
+ " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
|
||||
return currentRuntime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
return getCurrentRuntime().put(obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId) {
|
||||
return getCurrentRuntime().get(objectId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return getCurrentRuntime().get(objectIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
getCurrentRuntime().setResource(resourceName, capacity, nodeId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
return getCurrentRuntime().call(func, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, RayActor<?> actor, Object[] args) {
|
||||
return getCurrentRuntime().call(func, actor, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RuntimeContext getRuntimeContext() {
|
||||
return getCurrentRuntime().getRuntimeContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
return getCurrentRuntime().callPy(moduleName, functionName, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) {
|
||||
return getCurrentRuntime().callPy(pyActor, functionName, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createPyActor(moduleName, className, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
RayNativeRuntime runtime = getCurrentRuntime();
|
||||
return () -> {
|
||||
currentThreadRuntime.set(runtime);
|
||||
runnable.run();
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
RayNativeRuntime runtime = getCurrentRuntime();
|
||||
return () -> {
|
||||
currentThreadRuntime.set(runtime);
|
||||
return callable.call();
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -2,9 +2,12 @@ package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.lang.reflect.Field;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
@@ -46,15 +49,20 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
}
|
||||
nativeSetup(RayConfig.create().logDir);
|
||||
|
||||
RayConfig globalRayConfig = RayConfig.create();
|
||||
resetLibraryPath(globalRayConfig);
|
||||
|
||||
try {
|
||||
FileUtils.forceMkdir(new File(globalRayConfig.logDir));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to create the log directory.", e);
|
||||
}
|
||||
nativeSetup(globalRayConfig.logDir);
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
|
||||
}
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
protected void resetLibraryPath() {
|
||||
private static void resetLibraryPath(RayConfig rayConfig) {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
@@ -81,10 +89,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath();
|
||||
resetLibraryPath(rayConfig);
|
||||
|
||||
if (rayConfig.getRedisAddress() == null) {
|
||||
manager = new RunManager(rayConfig);
|
||||
|
||||
@@ -62,6 +62,8 @@ public class RayConfig {
|
||||
*/
|
||||
public final int numberExecThreadsForDevRuntime;
|
||||
|
||||
public final int numWorkersPerProcess;
|
||||
|
||||
private void validate() {
|
||||
if (workerMode == WorkerType.WORKER) {
|
||||
Preconditions.checkArgument(redisAddress != null,
|
||||
@@ -171,6 +173,8 @@ public class RayConfig {
|
||||
// Number of threads that execute tasks.
|
||||
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
|
||||
|
||||
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
|
||||
|
||||
// Validate config.
|
||||
validate();
|
||||
LOGGER.debug("Created config: {}", this);
|
||||
|
||||
@@ -320,10 +320,13 @@ public class RunManager {
|
||||
cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword);
|
||||
}
|
||||
|
||||
// Number of workers per Java worker process
|
||||
cmd.add("-Dray.raylet.config.num_workers_per_process_java=RAY_WORKER_NUM_WORKERS_PLACEHOLDER");
|
||||
|
||||
cmd.addAll(rayConfig.jvmParameters);
|
||||
|
||||
// jvm options
|
||||
cmd.add("RAY_WORKER_OPTION_0");
|
||||
cmd.add("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0");
|
||||
|
||||
// Main class
|
||||
cmd.add(WORKER_CLASS);
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package org.ray.runtime.runner.worker;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -20,7 +22,14 @@ public class DefaultWorker {
|
||||
});
|
||||
Ray.init();
|
||||
LOGGER.info("Worker started.");
|
||||
((RayNativeRuntime)Ray.internal()).run();
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayNativeRuntime) {
|
||||
((RayNativeRuntime)runtime).run();
|
||||
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
((RayMultiWorkerNativeRuntime)runtime).run();
|
||||
} else {
|
||||
throw new RuntimeException("Unknown RayRuntime: " + runtime);
|
||||
}
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to start worker.", e);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@ import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
@@ -41,7 +43,11 @@ public class ArgumentsBuilder {
|
||||
} else {
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
id = ((AbstractRayRuntime) runtime).getObjectStore()
|
||||
.put(new NativeRayObject(serialized, null));
|
||||
} else {
|
||||
data = serialized;
|
||||
|
||||
@@ -87,6 +87,7 @@ ray {
|
||||
|
||||
// See src/ray/ray_config_def.h for options.
|
||||
config {
|
||||
num_workers_per_process_java: 10
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package org.ray.api;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.function.Supplier;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.testng.Assert;
|
||||
import org.testng.SkipException;
|
||||
@@ -12,8 +15,7 @@ public class TestUtils {
|
||||
private static final int WAIT_INTERVAL_MS = 5;
|
||||
|
||||
public static void skipTestUnderSingleProcess() {
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime)Ray.internal();
|
||||
if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
if (getRuntime().getRayConfig().runMode == RunMode.SINGLE_PROCESS) {
|
||||
throw new SkipException("This test doesn't work under single-process mode.");
|
||||
}
|
||||
}
|
||||
@@ -62,4 +64,13 @@ public class TestUtils {
|
||||
RayObject<String> obj = Ray.call(TestUtils::hi);
|
||||
Assert.assertEquals(obj.get(), "hi");
|
||||
}
|
||||
|
||||
public static AbstractRayRuntime getRuntime() {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
Preconditions.checkState(runtime instanceof AbstractRayRuntime);
|
||||
return (AbstractRayRuntime) runtime;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.testng.Assert;
|
||||
@@ -117,7 +116,7 @@ public class ActorTest extends BaseTest {
|
||||
Ray.internal().free(ImmutableList.of(value.getId()), false, false);
|
||||
// Wait until the object is deleted, because the above free operation is async.
|
||||
while (true) {
|
||||
NativeRayObject result = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
NativeRayObject result = TestUtils.getRuntime().getObjectStore()
|
||||
.getRaw(ImmutableList.of(value.getId()), 0).get(0);
|
||||
if (result == null) {
|
||||
break;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableMap;
|
||||
import java.io.File;
|
||||
@@ -63,7 +64,7 @@ public abstract class BaseMultiLanguageTest {
|
||||
// Start ray cluster.
|
||||
String workerOptions =
|
||||
" -classpath " + System.getProperty("java.class.path");
|
||||
final List<String> startCommand = ImmutableList.of(
|
||||
List<String> startCommand = ImmutableList.of(
|
||||
"ray",
|
||||
"start",
|
||||
"--head",
|
||||
@@ -74,6 +75,13 @@ public abstract class BaseMultiLanguageTest {
|
||||
"--include-java",
|
||||
"--java-worker-options=" + workerOptions
|
||||
);
|
||||
String numWorkersPerProcessJava = System
|
||||
.getProperty("ray.raylet.config.num_workers_per_process_java");
|
||||
if (!Strings.isNullOrEmpty(numWorkersPerProcessJava)) {
|
||||
startCommand = ImmutableList.<String>builder().addAll(startCommand)
|
||||
.add(String.format("--internal-config={\"num_workers_per_process_java\": %s}",
|
||||
numWorkersPerProcessJava)).build();
|
||||
}
|
||||
if (!executeCommand(startCommand, 10, getRayStartEnv())) {
|
||||
throw new RuntimeException("Couldn't start ray cluster.");
|
||||
}
|
||||
|
||||
@@ -127,6 +127,7 @@ public class FailureTest extends BaseTest {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
List<RayFunc0<Integer>> badFunctions = Arrays.asList(FailureTest::badFunc,
|
||||
FailureTest::badFunc2);
|
||||
TestUtils.warmUpCluster();
|
||||
for (RayFunc0<Integer> badFunc : badFunctions) {
|
||||
RayObject<Integer> obj1 = Ray.call(badFunc);
|
||||
RayObject<Integer> obj2 = Ray.call(FailureTest::slowFunc);
|
||||
|
||||
@@ -2,11 +2,9 @@ package org.ray.api.test;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.runtimecontext.NodeInfo;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.testng.Assert;
|
||||
@@ -29,10 +27,10 @@ public class GcsClientTest extends BaseTest {
|
||||
@Test
|
||||
public void testGetAllNodeInfo() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig();
|
||||
RayConfig config = TestUtils.getRuntime().getRayConfig();
|
||||
|
||||
Preconditions.checkNotNull(config);
|
||||
GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient();
|
||||
GcsClient gcsClient = TestUtils.getRuntime().getGcsClient();
|
||||
List<NodeInfo> allNodeInfo = gcsClient.getAllNodeInfo();
|
||||
Assert.assertEquals(allNodeInfo.size(), 1);
|
||||
Assert.assertEquals(allNodeInfo.get(0).nodeAddress, config.nodeIp);
|
||||
@@ -43,11 +41,11 @@ public class GcsClientTest extends BaseTest {
|
||||
@Test
|
||||
public void testNextJob() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig();
|
||||
RayConfig config = TestUtils.getRuntime().getRayConfig();
|
||||
// The value of job id of this driver in cluster should be `1L`.
|
||||
Assert.assertEquals(config.getJobId(), JobId.fromInt(1));
|
||||
|
||||
GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient();
|
||||
GcsClient gcsClient = TestUtils.getRuntime().getGcsClient();
|
||||
for (int i = 2; i < 100; ++i) {
|
||||
Assert.assertEquals(gcsClient.nextJobId(), JobId.fromInt(i));
|
||||
}
|
||||
|
||||
@@ -54,19 +54,22 @@ public class MultiThreadingTest extends BaseTest {
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public ActorId getCurrentActorId() {
|
||||
final ActorId[] result = new ActorId[1];
|
||||
Thread thread = new Thread(() -> {
|
||||
result[0] = Ray.getRuntimeContext().getCurrentActorId();
|
||||
});
|
||||
public ActorId getCurrentActorId() throws Exception {
|
||||
final Object[] result = new Object[1];
|
||||
Thread thread = new Thread(Ray.wrapRunnable(() -> {
|
||||
try {
|
||||
result[0] = Ray.getRuntimeContext().getCurrentActorId();
|
||||
} catch (Exception e) {
|
||||
result[0] = e;
|
||||
}
|
||||
}));
|
||||
thread.start();
|
||||
try {
|
||||
thread.join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
thread.join();
|
||||
if (result[0] instanceof Exception) {
|
||||
throw (Exception) result[0];
|
||||
}
|
||||
Assert.assertEquals(result[0], actorId);
|
||||
return result[0];
|
||||
return (ActorId) result[0];
|
||||
}
|
||||
}
|
||||
|
||||
@@ -147,13 +150,13 @@ public class MultiThreadingTest extends BaseTest {
|
||||
try {
|
||||
List<Future<String>> futures = new ArrayList<>();
|
||||
for (int i = 0; i < NUM_THREADS; i++) {
|
||||
Callable<String> task = () -> {
|
||||
Callable<String> task = Ray.wrapCallable(() -> {
|
||||
for (int j = 0; j < numRepeats; j++) {
|
||||
TimeUnit.MILLISECONDS.sleep(1);
|
||||
testCase.run();
|
||||
}
|
||||
return "ok";
|
||||
};
|
||||
});
|
||||
futures.add(service.submit(task));
|
||||
}
|
||||
for (Future<String> future : futures) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -26,7 +25,7 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
Ray.internal().free(ImmutableList.of(helloId.getId()), true, false);
|
||||
|
||||
final boolean result = TestUtils.waitForCondition(() ->
|
||||
((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
TestUtils.getRuntime().getObjectStore()
|
||||
.getRaw(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50);
|
||||
Assert.assertTrue(result);
|
||||
}
|
||||
@@ -40,7 +39,7 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
|
||||
TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH));
|
||||
final boolean result = TestUtils.waitForCondition(
|
||||
() -> !(((AbstractRayRuntime) Ray.internal()).getGcsClient())
|
||||
() -> !TestUtils.getRuntime().getGcsClient()
|
||||
.rayletTaskExistsInGcs(taskId), 50);
|
||||
Assert.assertTrue(result);
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.util.Collections;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -16,8 +13,7 @@ public class PlasmaStoreTest extends BaseTest {
|
||||
public void testPutWithDuplicateId() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
|
||||
ObjectStore objectStore = runtime.getObjectStore();
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
objectStore.put("1", objectId);
|
||||
Assert.assertEquals(Ray.get(objectId), "1");
|
||||
objectStore.put("2", objectId);
|
||||
|
||||
@@ -6,9 +6,9 @@ import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -80,7 +80,7 @@ public class RayCallTest extends BaseTest {
|
||||
@RayRemote
|
||||
private static void testNoReturn(ObjectId objectId) {
|
||||
// Put an object in object store to inform driver that this function is executing.
|
||||
((AbstractRayRuntime) Ray.internal()).getObjectStore().put(1, objectId);
|
||||
TestUtils.getRuntime().getObjectStore().put(1, objectId);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -2,8 +2,8 @@ package org.ray.api.test;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.testng.Assert;
|
||||
@@ -14,7 +14,7 @@ public class RaySerializerTest extends BaseMultiLanguageTest {
|
||||
@Test
|
||||
public void testSerializePyActor() {
|
||||
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
|
||||
ObjectStore objectStore = ((AbstractRayRuntime) Ray.internal()).getObjectStore();
|
||||
ObjectStore objectStore = TestUtils.getRuntime().getObjectStore();
|
||||
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) objectStore
|
||||
.deserialize(nativeRayObject, ObjectId.fromRandom());
|
||||
|
||||
Reference in New Issue
Block a user