[Java] Support multiple workers in Java worker process (#5505)

This commit is contained in:
Kai Yang
2019-09-07 22:52:05 +08:00
committed by Hao Chen
parent d89ceb3ee5
commit 732336fc4f
37 changed files with 512 additions and 148 deletions
@@ -4,6 +4,7 @@ import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
@@ -59,11 +60,6 @@ public abstract class AbstractRayRuntime implements RayRuntime {
runtimeContext = new RuntimeContextImpl(this);
}
/**
* Start runtime.
*/
public abstract void start() throws Exception;
@Override
public abstract void shutdown();
@@ -168,6 +164,16 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
}
/**
 * Returns the given {@link Runnable} unchanged; this implementation applies no wrapping.
 *
 * @param runnable the runnable to wrap
 * @return the same {@code runnable} instance that was passed in
 */
@Override
public Runnable wrapRunnable(Runnable runnable) {
return runnable;
}
/**
 * Returns the given {@link Callable} unchanged; this implementation applies no wrapping.
 *
 * @param callable the callable to wrap
 * @return the same {@code callable} instance that was passed in
 */
@Override
public Callable wrapCallable(Callable callable) {
return callable;
}
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
Object[] args, int numReturns, CallOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
@@ -4,11 +4,12 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtime.RayRuntimeFactory;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The default Ray runtime factory. It produces an instance of AbstractRayRuntime.
* The default Ray runtime factory. It produces an instance of RayRuntime.
*/
public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
@@ -18,14 +19,16 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.create();
try {
AbstractRayRuntime runtime;
RayRuntime runtime;
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
runtime = new RayDevRuntime(rayConfig);
} else {
runtime = new RayNativeRuntime(rayConfig);
if (rayConfig.workerMode == WorkerType.DRIVER) {
runtime = new RayNativeRuntime(rayConfig);
} else {
runtime = new RayMultiWorkerNativeRuntime(rayConfig);
}
}
runtime.start();
return runtime;
} catch (Exception e) {
LOGGER.error("Failed to initialize ray runtime", e);
@@ -11,14 +11,10 @@ import org.ray.runtime.task.TaskExecutor;
public class RayDevRuntime extends AbstractRayRuntime {
public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
}
private AtomicInteger jobCounter = new AtomicInteger(0);
@Override
public void start() {
public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
@@ -0,0 +1,192 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.function.RayFunc;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
 * instances and redirects calls to the correct one based on thread context.
 *
 * <p>Each worker owns one thread. The thread creates its own {@link RayNativeRuntime}, binds it
 * to a {@link ThreadLocal}, and then calls {@link RayNativeRuntime#run()}. Every API call made
 * through this proxy is forwarded to the runtime bound to the calling thread.
 */
public class RayMultiWorkerNativeRuntime implements RayRuntime {

  private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);

  /**
   * The number of workers per worker process.
   */
  private final int numWorkers;

  /**
   * The worker threads.
   */
  private final Thread[] threads;

  /**
   * The {@link RayNativeRuntime} instances of workers. A slot stays {@code null} until the
   * corresponding worker thread has been started and has created its runtime.
   */
  // NOTE(review): slots are written by worker threads and read by whichever thread calls
  // shutdown(), with no explicit synchronization -- confirm whether stronger visibility
  // guarantees are needed.
  private final RayNativeRuntime[] runtimes;

  /**
   * The {@link RayNativeRuntime} instance of current thread.
   */
  private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();

  /**
   * Prepares (but does not start) one thread per worker.
   *
   * @param rayConfig the configuration handed to every {@link RayNativeRuntime}; it must be in
   *     cluster run mode, in worker mode, and have a positive {@code numWorkersPerProcess}
   * @throws IllegalStateException if any of the above requirements is violated
   */
  public RayMultiWorkerNativeRuntime(RayConfig rayConfig) {
    Preconditions.checkState(
        rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
    Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
        "numWorkersPerProcess must be greater than 0.");
    numWorkers = rayConfig.numWorkersPerProcess;
    runtimes = new RayNativeRuntime[numWorkers];
    threads = new Thread[numWorkers];

    LOGGER.info("Starting {} workers.", numWorkers);

    for (int i = 0; i < numWorkers; i++) {
      final int workerIndex = i;
      threads[i] = new Thread(() -> {
        // The runtime is created on the worker thread itself, then published to both the
        // shared array (for shutdown) and the thread-local (for API dispatch).
        RayNativeRuntime runtime = new RayNativeRuntime(rayConfig);
        runtimes[workerIndex] = runtime;
        currentThreadRuntime.set(runtime);
        runtime.run();
      });
    }
  }

  /**
   * Starts all worker threads and blocks until every one of them has terminated.
   *
   * @throws RuntimeException wrapping an {@link InterruptedException} if the calling thread is
   *     interrupted while waiting; the interrupt status is restored before throwing
   */
  public void run() {
    for (Thread thread : threads) {
      thread.start();
    }
    joinWorkerThreads();
  }

  /**
   * Shuts down every worker runtime that has been created, then waits for the worker threads
   * to terminate.
   *
   * @throws RuntimeException wrapping an {@link InterruptedException} if the calling thread is
   *     interrupted while waiting; the interrupt status is restored before throwing
   */
  @Override
  public void shutdown() {
    for (RayNativeRuntime runtime : runtimes) {
      // A slot is null when its worker thread was never started or has not created its
      // runtime yet; there is nothing to shut down in that case.
      if (runtime != null) {
        runtime.shutdown();
      }
    }
    joinWorkerThreads();
  }

  /**
   * Waits for all worker threads to terminate.
   */
  private void joinWorkerThreads() {
    for (Thread thread : threads) {
      try {
        thread.join();
      } catch (InterruptedException e) {
        // Preserve the interrupt status for callers further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Returns the runtime bound to the calling thread.
   *
   * @return the {@link RayNativeRuntime} of the current thread, never {@code null}
   * @throws NullPointerException if no runtime is bound to the current thread
   */
  public RayNativeRuntime getCurrentRuntime() {
    RayNativeRuntime currentRuntime = currentThreadRuntime.get();
    Preconditions.checkNotNull(currentRuntime,
        "RayRuntime is not set on current thread."
            + " If you want to use Ray API in your own threads,"
            + " please wrap your `Runnable`s or `Callable`s with"
            + " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
    return currentRuntime;
  }

  // ---------------------------------------------------------------------------------------
  // The methods below forward to the runtime bound to the calling thread.
  // ---------------------------------------------------------------------------------------

  @Override
  public <T> RayObject<T> put(T obj) {
    return getCurrentRuntime().put(obj);
  }

  @Override
  public <T> T get(ObjectId objectId) {
    return getCurrentRuntime().get(objectId);
  }

  @Override
  public <T> List<T> get(List<ObjectId> objectIds) {
    return getCurrentRuntime().get(objectIds);
  }

  @Override
  public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
    return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
  }

  @Override
  public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
    getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
  }

  @Override
  public void setResource(String resourceName, double capacity, UniqueId nodeId) {
    getCurrentRuntime().setResource(resourceName, capacity, nodeId);
  }

  @Override
  public RayObject call(RayFunc func, Object[] args, CallOptions options) {
    return getCurrentRuntime().call(func, args, options);
  }

  @Override
  public RayObject call(RayFunc func, RayActor<?> actor, Object[] args) {
    return getCurrentRuntime().call(func, actor, args);
  }

  @Override
  public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
  }

  @Override
  public RuntimeContext getRuntimeContext() {
    return getCurrentRuntime().getRuntimeContext();
  }

  @Override
  public RayObject callPy(String moduleName, String functionName, Object[] args,
      CallOptions options) {
    return getCurrentRuntime().callPy(moduleName, functionName, args, options);
  }

  @Override
  public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) {
    return getCurrentRuntime().callPy(pyActor, functionName, args);
  }

  @Override
  public RayPyActor createPyActor(String moduleName, String className, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createPyActor(moduleName, className, args, options);
  }

  /**
   * Wraps a {@link Runnable} so that, while it runs, the executing thread is bound to the
   * runtime of the thread that called this method.
   *
   * <p>The binding the executing thread had before is restored once the runnable finishes,
   * so a pooled thread does not keep a stale runtime and a worker thread does not lose its own.
   *
   * @param runnable the runnable to wrap
   * @return a runnable that binds the captured runtime around {@code runnable.run()}
   * @throws NullPointerException if no runtime is bound to the thread calling this method
   */
  @Override
  public Runnable wrapRunnable(Runnable runnable) {
    RayNativeRuntime runtime = getCurrentRuntime();
    return () -> {
      RayNativeRuntime previous = currentThreadRuntime.get();
      currentThreadRuntime.set(runtime);
      try {
        runnable.run();
      } finally {
        restoreRuntime(previous);
      }
    };
  }

  /**
   * Wraps a {@link Callable} so that, while it runs, the executing thread is bound to the
   * runtime of the thread that called this method.
   *
   * <p>The binding the executing thread had before is restored once the callable finishes,
   * whether it returns normally or throws.
   *
   * @param callable the callable to wrap
   * @return a callable that binds the captured runtime around {@code callable.call()}
   * @throws NullPointerException if no runtime is bound to the thread calling this method
   */
  @Override
  public Callable wrapCallable(Callable callable) {
    RayNativeRuntime runtime = getCurrentRuntime();
    return () -> {
      RayNativeRuntime previous = currentThreadRuntime.get();
      currentThreadRuntime.set(runtime);
      try {
        return callable.call();
      } finally {
        restoreRuntime(previous);
      }
    };
  }

  /**
   * Puts the current thread's runtime binding back to {@code previous}, clearing the
   * thread-local entirely when there was no binding before.
   */
  private void restoreRuntime(RayNativeRuntime previous) {
    if (previous == null) {
      currentThreadRuntime.remove();
    } else {
      currentThreadRuntime.set(previous);
    }
  }
}
@@ -2,9 +2,12 @@ package org.ray.runtime;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.id.JobId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
@@ -46,15 +49,20 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
LOGGER.debug("Native libraries loaded.");
}
nativeSetup(RayConfig.create().logDir);
RayConfig globalRayConfig = RayConfig.create();
resetLibraryPath(globalRayConfig);
try {
FileUtils.forceMkdir(new File(globalRayConfig.logDir));
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
nativeSetup(globalRayConfig.logDir);
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
}
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
protected void resetLibraryPath() {
private static void resetLibraryPath(RayConfig rayConfig) {
if (rayConfig.libraryPath.isEmpty()) {
return;
}
@@ -81,10 +89,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
}
@Override
public void start() {
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
// Reset library path at runtime.
resetLibraryPath();
resetLibraryPath(rayConfig);
if (rayConfig.getRedisAddress() == null) {
manager = new RunManager(rayConfig);
@@ -62,6 +62,8 @@ public class RayConfig {
*/
public final int numberExecThreadsForDevRuntime;
public final int numWorkersPerProcess;
private void validate() {
if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
@@ -171,6 +173,8 @@ public class RayConfig {
// Number of threads that execute tasks.
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
// Validate config.
validate();
LOGGER.debug("Created config: {}", this);
@@ -320,10 +320,13 @@ public class RunManager {
cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword);
}
// Number of workers per Java worker process
cmd.add("-Dray.raylet.config.num_workers_per_process_java=RAY_WORKER_NUM_WORKERS_PLACEHOLDER");
cmd.addAll(rayConfig.jvmParameters);
// jvm options
cmd.add("RAY_WORKER_OPTION_0");
cmd.add("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0");
// Main class
cmd.add(WORKER_CLASS);
@@ -1,6 +1,8 @@
package org.ray.runtime.runner.worker;
import org.ray.api.Ray;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -20,7 +22,14 @@ public class DefaultWorker {
});
Ray.init();
LOGGER.info("Worker started.");
((RayNativeRuntime)Ray.internal()).run();
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayNativeRuntime) {
((RayNativeRuntime)runtime).run();
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
((RayMultiWorkerNativeRuntime)runtime).run();
} else {
throw new RuntimeException("Unknown RayRuntime: " + runtime);
}
} catch (Exception e) {
LOGGER.error("Failed to start worker.", e);
}
@@ -5,7 +5,9 @@ import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.ObjectId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.util.Serializer;
@@ -41,7 +43,11 @@ public class ArgumentsBuilder {
} else {
byte[] serialized = Serializer.encode(arg);
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
id = ((AbstractRayRuntime) runtime).getObjectStore()
.put(new NativeRayObject(serialized, null));
} else {
data = serialized;
@@ -87,6 +87,7 @@ ray {
// See src/ray/ray_config_def.h for options.
config {
num_workers_per_process_java: 10
}
}