Support multiple core workers in one process (#7623)

This commit is contained in:
Kai Yang
2020-04-07 11:01:47 +08:00
committed by GitHub
parent e91595f955
commit 48b48cc8c2
90 changed files with 2014 additions and 1411 deletions
@@ -19,7 +19,6 @@ import org.ray.api.function.RayFuncVoid;
import org.ray.api.id.ObjectId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.RuntimeContextImpl;
@@ -29,6 +28,7 @@ import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.generated.Common.WorkerType;
import org.ray.runtime.object.ObjectStore;
import org.ray.runtime.object.RayObjectImpl;
import org.ray.runtime.task.ArgumentsBuilder;
@@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class AbstractRayRuntime implements RayRuntime {
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
@@ -55,9 +55,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected TaskSubmitter taskSubmitter;
protected WorkerContext workerContext;
public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) {
/**
* Whether the required thread context is set on the current thread.
*/
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
this.functionManager = functionManager;
setIsContextSet(rayConfig.workerMode == WorkerType.DRIVER);
functionManager = new FunctionManager(rayConfig.jobResourcePath);
runtimeContext = new RuntimeContextImpl(this);
}
@@ -161,13 +167,54 @@ public abstract class AbstractRayRuntime implements RayRuntime {
}
@Override
public Runnable wrapRunnable(Runnable runnable) {
return runnable;
public void setAsyncContext(Object asyncContext) {
isContextSet.set(true);
}
// TODO (kfstorm): Simplify the duplicate code in wrap*** methods.
/**
 * Wraps a {@link Runnable} so that it runs with the async context of the thread that called this
 * method. The context is captured here; the returned runnable installs it right before running
 * the task and puts the executing thread's previous state back afterwards.
 *
 * @param runnable the task to wrap
 * @return a runnable that installs the captured context around {@code runnable.run()}
 */
@Override
public final Runnable wrapRunnable(Runnable runnable) {
  final Object capturedContext = getAsyncContext();
  return () -> {
    // Only read the executing thread's context when the flag says it actually has one.
    final boolean hadContext = isContextSet.get();
    final Object previousContext = hadContext ? getAsyncContext() : null;
    setAsyncContext(capturedContext);
    try {
      runnable.run();
    } finally {
      if (!hadContext) {
        // The thread had no context before: just clear the flag again.
        setIsContextSet(false);
      } else {
        setAsyncContext(previousContext);
      }
    }
  };
}
@Override
public Callable wrapCallable(Callable callable) {
return callable;
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
  // Snapshot the async context of the thread that asks for the wrapping.
  final Object capturedContext = getAsyncContext();
  return () -> {
    // Only read the executing thread's context when the flag says it actually has one.
    final boolean hadContext = isContextSet.get();
    final Object previousContext = hadContext ? getAsyncContext() : null;
    setAsyncContext(capturedContext);
    try {
      return callable.call();
    } finally {
      if (!hadContext) {
        // The thread had no context before: just clear the flag again.
        setIsContextSet(false);
      } else {
        setAsyncContext(previousContext);
      }
    }
  };
}
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
@@ -209,18 +256,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return actor;
}
/**
 * Returns the worker context held by this runtime.
 */
@Override
public WorkerContext getWorkerContext() {
return workerContext;
}
/**
 * Returns the object store held by this runtime.
 */
@Override
public ObjectStore getObjectStore() {
return objectStore;
}
/**
 * Returns the function manager held by this runtime.
 */
@Override
public FunctionManager getFunctionManager() {
return functionManager;
}
/**
 * Returns the configuration this runtime was constructed with.
 */
@Override
public RayConfig getRayConfig() {
return rayConfig;
}
@@ -229,7 +280,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return runtimeContext;
}
/**
 * Returns the GCS client held by this runtime.
 */
@Override
public GcsClient getGcsClient() {
return gcsClient;
}
/**
 * Records whether the required thread context is set. The flag is stored in a
 * {@link ThreadLocal}, so the value only applies to the calling thread.
 *
 * @param isContextSet the new value of the flag for the calling thread
 */
@Override
public void setIsContextSet(boolean isContextSet) {
this.isContextSet.set(isContextSet);
}
}
@@ -4,8 +4,6 @@ import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtime.RayRuntimeFactory;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -20,17 +18,13 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
public RayRuntime createRayRuntime() {
RayConfig rayConfig = RayConfig.getInstance();
try {
FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath);
RayRuntime runtime;
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
runtime = new RayDevRuntime(rayConfig, functionManager);
} else {
if (rayConfig.workerMode == WorkerType.DRIVER) {
runtime = new RayNativeRuntime(rayConfig, functionManager);
} else {
runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager);
}
}
AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
LOGGER.error("Failed to initialize ray runtime", e);
@@ -1,12 +1,12 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.BaseActor;
import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.task.LocalModeTaskExecutor;
import org.ray.runtime.task.LocalModeTaskSubmitter;
@@ -19,18 +19,31 @@ public class RayDevRuntime extends AbstractRayRuntime {
private AtomicInteger jobCounter = new AtomicInteger(0);
public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
/**
 * Creates the runtime from the given configuration. Nothing else is set up here; the worker
 * context, object store, and task submitter are created in {@code start()}.
 *
 * @param rayConfig the configuration passed on to the base runtime
 */
public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
}
@Override
public void start() {
if (rayConfig.getJobId().isNil()) {
rayConfig.setJobId(nextJobId());
}
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
rayConfig.numberExecThreadsForDevRuntime);
taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor,
(LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore).addObjectPutCallback(
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
objectId -> {
if (taskSubmitter != null) {
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
}
});
}
/**
 * Not supported by this runtime.
 *
 * @throws UnsupportedOperationException always
 */
@Override
public void run() {
throw new UnsupportedOperationException();
}
@Override
@@ -60,6 +73,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
/**
 * Accepts only a {@code null} context and then delegates to the base implementation.
 *
 * @param asyncContext must be {@code null}
 * @throws IllegalArgumentException if {@code asyncContext} is not {@code null}
 */
@Override
public void setAsyncContext(Object asyncContext) {
// This runtime has no context object of its own, so anything non-null is a caller error.
Preconditions.checkArgument(asyncContext == null);
super.setAsyncContext(asyncContext);
}
private JobId nextJobId() {
@@ -1,215 +0,0 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.concurrent.Callable;
import org.ray.api.BaseActor;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RayPyActor;
import org.ray.api.WaitResult;
import org.ray.api.function.PyActorClass;
import org.ray.api.function.PyActorMethod;
import org.ray.api.function.PyRemoteFunction;
import org.ray.api.function.RayFunc;
import org.ray.api.id.ObjectId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.api.runtime.RayRuntime;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.generated.Common.WorkerType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
 * instances and redirects calls to the correct one based on thread context.
 *
 * <p>Each worker thread creates its own {@link RayNativeRuntime} and binds it to itself through a
 * {@link ThreadLocal}. Other threads obtain a binding through {@link #wrapRunnable(Runnable)},
 * {@link #wrapCallable(Callable)} or {@link #setAsyncContext(Object)}.
 */
public class RayMultiWorkerNativeRuntime implements RayRuntime {

  private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);

  private final FunctionManager functionManager;

  /**
   * The number of workers per worker process.
   */
  private final int numWorkers;

  /**
   * The worker threads.
   */
  private final Thread[] threads;

  /**
   * The {@link RayNativeRuntime} instances of workers. Slot {@code i} is filled by worker thread
   * {@code i} once it runs, so a slot is {@code null} until then.
   */
  private final RayNativeRuntime[] runtimes;

  /**
   * The {@link RayNativeRuntime} instance of current thread.
   */
  private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();

  /**
   * Prepares (but does not start) one thread per worker. Each thread, once started by
   * {@link #run()}, creates its own {@link RayNativeRuntime} and runs it.
   *
   * @param rayConfig must describe a cluster-mode worker with a positive number of workers per
   *     process
   * @param functionManager shared by all worker runtimes
   * @throws IllegalStateException if the configuration does not meet the requirements above
   */
  public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
    this.functionManager = functionManager;
    Preconditions.checkState(
        rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
    Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
        "numWorkersPerProcess must be greater than 0.");
    numWorkers = rayConfig.numWorkersPerProcess;
    runtimes = new RayNativeRuntime[numWorkers];
    threads = new Thread[numWorkers];

    LOGGER.info("Starting {} workers.", numWorkers);

    for (int i = 0; i < numWorkers; i++) {
      final int workerIndex = i;
      threads[i] = new Thread(() -> {
        RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager);
        runtimes[workerIndex] = runtime;
        currentThreadRuntime.set(runtime);
        runtime.run();
      });
    }
  }

  /**
   * Starts all worker threads and blocks until every one of them has finished.
   *
   * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
   *     interrupted while waiting; the interrupt flag is restored before throwing
   */
  public void run() {
    for (Thread thread : threads) {
      thread.start();
    }
    joinWorkerThreads();
  }

  /**
   * Shuts down every worker runtime that has been created and then waits for the worker threads
   * to finish. Workers whose runtime was never created (for example because {@link #run()} was
   * never called) are skipped instead of failing with a {@link NullPointerException}.
   *
   * @throws RuntimeException wrapping the {@link InterruptedException} if the calling thread is
   *     interrupted while waiting; the interrupt flag is restored before throwing
   */
  @Override
  public void shutdown() {
    for (RayNativeRuntime runtime : runtimes) {
      if (runtime != null) {
        runtime.shutdown();
      }
    }
    joinWorkerThreads();
  }

  /**
   * Waits for all worker threads, preserving the caller's interrupt status on interruption.
   */
  private void joinWorkerThreads() {
    for (Thread thread : threads) {
      try {
        thread.join();
      } catch (InterruptedException e) {
        // Keep the interrupt visible to callers further up the stack.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Returns the runtime bound to the calling thread.
   *
   * @throws NullPointerException if no runtime is bound to the calling thread
   */
  public RayNativeRuntime getCurrentRuntime() {
    RayNativeRuntime currentRuntime = currentThreadRuntime.get();
    Preconditions.checkNotNull(currentRuntime,
        "RayRuntime is not set on current thread."
            + " If you want to use Ray API in your own threads,"
            + " please wrap your `Runnable`s or `Callable`s with"
            + " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
    return currentRuntime;
  }

  @Override
  public <T> RayObject<T> put(T obj) {
    return getCurrentRuntime().put(obj);
  }

  @Override
  public <T> T get(ObjectId objectId) {
    return getCurrentRuntime().get(objectId);
  }

  @Override
  public <T> List<T> get(List<ObjectId> objectIds) {
    return getCurrentRuntime().get(objectIds);
  }

  @Override
  public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
    return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
  }

  @Override
  public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
    getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
  }

  @Override
  public void setResource(String resourceName, double capacity, UniqueId nodeId) {
    getCurrentRuntime().setResource(resourceName, capacity, nodeId);
  }

  @Override
  public void killActor(BaseActor actor, boolean noReconstruction) {
    getCurrentRuntime().killActor(actor, noReconstruction);
  }

  @Override
  public RayObject call(RayFunc func, Object[] args, CallOptions options) {
    return getCurrentRuntime().call(func, args, options);
  }

  @Override
  public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
      CallOptions options) {
    return getCurrentRuntime().call(pyRemoteFunction, args, options);
  }

  @Override
  public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
    return getCurrentRuntime().callActor(actor, func, args);
  }

  @Override
  public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
    return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
  }

  @Override
  public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
  }

  @Override
  public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
      ActorCreationOptions options) {
    return getCurrentRuntime().createActor(pyActorClass, args, options);
  }

  @Override
  public RuntimeContext getRuntimeContext() {
    return getCurrentRuntime().getRuntimeContext();
  }

  /**
   * Returns the runtime bound to the calling thread; this object is the "async context".
   */
  @Override
  public Object getAsyncContext() {
    return getCurrentRuntime();
  }

  /**
   * Binds the given runtime to the calling thread.
   *
   * @param asyncContext a {@link RayNativeRuntime}, as returned by {@link #getAsyncContext()}
   */
  @Override
  public void setAsyncContext(Object asyncContext) {
    currentThreadRuntime.set((RayNativeRuntime) asyncContext);
  }

  /**
   * Wraps a runnable so that it runs with the runtime of the thread that called this method. The
   * executing thread's previous binding is restored when the runnable finishes, so the captured
   * runtime does not stay attached to reused (e.g. pooled) threads.
   */
  @Override
  public Runnable wrapRunnable(Runnable runnable) {
    Object asyncContext = getAsyncContext();
    return () -> {
      RayNativeRuntime previousRuntime = currentThreadRuntime.get();
      setAsyncContext(asyncContext);
      try {
        runnable.run();
      } finally {
        restoreRuntime(previousRuntime);
      }
    };
  }

  /**
   * Wraps a callable so that it runs with the runtime of the thread that called this method. The
   * executing thread's previous binding is restored when the callable finishes, so the captured
   * runtime does not stay attached to reused (e.g. pooled) threads.
   */
  @Override
  public Callable wrapCallable(Callable callable) {
    Object asyncContext = getAsyncContext();
    return () -> {
      RayNativeRuntime previousRuntime = currentThreadRuntime.get();
      setAsyncContext(asyncContext);
      try {
        return callable.call();
      } finally {
        restoreRuntime(previousRuntime);
      }
    };
  }

  /**
   * Puts back the binding the calling thread had before a wrapped task ran.
   */
  private void restoreRuntime(RayNativeRuntime previousRuntime) {
    if (previousRuntime == null) {
      currentThreadRuntime.remove();
    } else {
      currentThreadRuntime.set(previousRuntime);
    }
  }
}
@@ -3,7 +3,6 @@ package org.ray.runtime;
import com.google.common.base.Preconditions;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.io.FileUtils;
import org.ray.api.BaseActor;
@@ -11,7 +10,6 @@ import org.ray.api.id.JobId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.NativeWorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.gcs.GcsClientOptions;
import org.ray.runtime.gcs.RedisClient;
@@ -34,10 +32,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private RunManager manager = null;
/**
* The native pointer of core worker.
*/
private long nativeCoreWorkerPointer;
static {
LOGGER.debug("Loading native libraries.");
@@ -55,14 +49,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
JniUtils.loadLibrary("core_worker_library_java", true);
LOGGER.debug("Native libraries loaded.");
// Reset library path at runtime.
resetLibraryPath(rayConfig);
try {
FileUtils.forceMkdir(new File(rayConfig.logDir));
} catch (IOException e) {
throw new RuntimeException("Failed to create the log directory.", e);
}
nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters);
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
}
private static void resetLibraryPath(RayConfig rayConfig) {
@@ -71,11 +64,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
JniUtils.resetLibraryPath(libraryPath);
}
public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
super(rayConfig, functionManager);
// Reset library path at runtime.
resetLibraryPath(rayConfig);
/**
 * Creates the runtime from the given configuration. The constructor only forwards the
 * configuration to the base runtime.
 *
 * @param rayConfig the configuration passed on to the base runtime
 */
public RayNativeRuntime(RayConfig rayConfig) {
super(rayConfig);
}
@Override
public void start() {
if (rayConfig.getRedisAddress() == null) {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
@@ -86,21 +80,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (rayConfig.getJobId() == JobId.NIL) {
rayConfig.setJobId(gcsClient.nextJobId());
}
int numWorkersPerProcess =
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(),
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig));
Preconditions.checkState(nativeCoreWorkerPointer != 0);
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayConfig.rayletConfigParameters);
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
// register
registerWorker();
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.info("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
@@ -108,10 +102,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
@Override
public void shutdown() {
if (nativeCoreWorkerPointer != 0) {
nativeDestroyCoreWorker(nativeCoreWorkerPointer);
nativeCoreWorkerPointer = 0;
}
nativeShutdown();
if (null != manager) {
manager.cleanup();
manager = null;
@@ -131,75 +122,56 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
if (nodeId == null) {
nodeId = UniqueId.NIL;
}
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
nativeSetResource(resourceName, capacity, nodeId.getBytes());
}
@Override
public void killActor(BaseActor actor, boolean noReconstruction) {
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
nativeKillActor(actor.getId().getBytes(), noReconstruction);
}
@Override
public Object getAsyncContext() {
return null;
return new AsyncContext(workerContext.getCurrentWorkerId(),
workerContext.getCurrentClassLoader());
}
/**
 * Applies an {@link AsyncContext}: hands its worker ID to the native layer, sets its class loader
 * on the worker context, and then delegates to the base implementation.
 *
 * @param asyncContext an {@link AsyncContext} instance
 */
@Override
public void setAsyncContext(Object asyncContext) {
  final AsyncContext context = (AsyncContext) asyncContext;
  nativeSetCoreWorker(context.workerId.getBytes());
  workerContext.setCurrentClassLoader(context.currentClassLoader);
  super.setAsyncContext(asyncContext);
}
@Override
public void run() {
nativeRunTaskExecutor(nativeCoreWorkerPointer);
Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER);
nativeRunTaskExecutor(taskExecutor);
}
public long getNativeCoreWorkerPointer() {
return nativeCoreWorkerPointer;
}
/**
 * Native entry point called once from {@code start()}.
 * NOTE(review): what the native side does with these values is not visible from Java; the
 * parameter notes below describe what {@code start()} passes in.
 *
 * @param workerMode numeric value of the configured worker type
 * @param nodeIpAddress IP address of this node
 * @param nodeManagerPort port of the node manager
 * @param driverName the {@code user.dir} system property for drivers, empty string otherwise
 * @param storeSocket object store socket name
 * @param rayletSocket raylet socket name
 * @param jobId bytes of the configured job ID for drivers, of the nil job ID otherwise
 * @param gcsClientOptions GCS client options built from the configuration
 * @param numWorkersPerProcess 1 for drivers, the configured number of workers otherwise
 * @param logDir log directory
 * @param rayletConfigParameters raylet configuration parameters
 */
private static native void nativeInitialize(int workerMode, String nodeIpAddress,
    int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
    byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
    String logDir, Map<String, String> rayletConfigParameters);
public TaskExecutor getTaskExecutor() {
return taskExecutor;
}
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
/**
* Register this worker or driver to GCS.
*/
private void registerWorker() {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
if (rayConfig.workerMode == WorkerType.DRIVER) {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
workerInfo.put("name", System.getProperty("user.dir"));
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
redisClient.hmset("Drivers:" + workerId, workerInfo);
} else {
workerInfo.put("node_ip_address", rayConfig.nodeIp);
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
//TODO: b"Workers:" + worker.workerId,
redisClient.hmset("Workers:" + workerId, workerInfo);
private static native void nativeShutdown();
private static native void nativeSetResource(String resourceName, double capacity, byte[] nodeId);
private static native void nativeKillActor(byte[] actorId, boolean noReconstruction);
private static native void nativeSetCoreWorker(byte[] workerId);
/**
 * Immutable pair of values that make up an async context for this runtime: a worker ID and a
 * class loader.
 */
static class AsyncContext {
/** ID of the worker this context belongs to. */
public final UniqueId workerId;
/** Class loader to set on the worker context when this context is applied. */
public final ClassLoader currentClassLoader;
AsyncContext(UniqueId workerId, ClassLoader currentClassLoader) {
this.workerId = workerId;
this.currentClassLoader = currentClassLoader;
}
}
private static native long nativeInitCoreWorker(int workerMode, String storeSocket,
String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId,
GcsClientOptions gcsClientOptions);
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer);
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
private static native void nativeSetup(String logDir, Map<String, String> rayletConfigParameters);
private static native void nativeShutdownHook();
private static native void nativeSetResource(long conn, String resourceName, double capacity,
byte[] nodeId);
private static native void nativeKillActor(long nativeCoreWorkerPointer, byte[] actorId,
boolean noReconstruction);
}
@@ -0,0 +1,33 @@
package org.ray.runtime;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.context.WorkerContext;
import org.ray.runtime.functionmanager.FunctionManager;
import org.ray.runtime.gcs.GcsClient;
import org.ray.runtime.object.ObjectStore;
/**
 * This interface is required to make {@link RayRuntimeProxy} work: the proxy is created with
 * {@code java.lang.reflect.Proxy}, which needs an interface to implement. It adds the
 * runtime-internal operations on top of the public {@link RayRuntime} API.
 */
public interface RayRuntimeInternal extends RayRuntime {
/**
 * Start runtime.
 */
void start();
/**
 * Returns the worker context of this runtime.
 */
WorkerContext getWorkerContext();
/**
 * Returns the object store of this runtime.
 */
ObjectStore getObjectStore();
/**
 * Returns the function manager of this runtime.
 */
FunctionManager getFunctionManager();
/**
 * Returns the configuration of this runtime.
 */
RayConfig getRayConfig();
/**
 * Returns the GCS client of this runtime.
 */
GcsClient getGcsClient();
/**
 * Sets the flag that records whether the required thread context is set.
 *
 * @param isContextSet the new value of the flag
 */
void setIsContextSet(boolean isContextSet);
/**
 * Runs the runtime's task execution. Implementations may not support this operation.
 */
void run();
}
@@ -0,0 +1,83 @@
package org.ray.runtime;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import org.ray.api.exception.RayException;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.RunMode;
/**
 * Protects a ray runtime with a context check ahead of every method declared by
 * {@link RayRuntime}, except {@code shutdown} and {@code setAsyncContext}. Methods that are not
 * declared by {@link RayRuntime} are forwarded without a check.
 */
public class RayRuntimeProxy implements InvocationHandler {

  /**
   * The original runtime. Never reassigned after construction.
   */
  private final AbstractRayRuntime obj;

  private RayRuntimeProxy(AbstractRayRuntime obj) {
    this.obj = obj;
  }

  /**
   * Returns the runtime that this handler forwards to.
   */
  public AbstractRayRuntime getRuntimeObject() {
    return obj;
  }

  /**
   * Generate a new instance of {@link RayRuntimeInternal} with additional context check.
   *
   * @param obj the runtime to protect
   * @return a dynamic proxy that implements {@link RayRuntimeInternal} and forwards to {@code obj}
   */
  static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
    return (RayRuntimeInternal) java.lang.reflect.Proxy
        .newProxyInstance(obj.getClass().getClassLoader(), new Class<?>[]{RayRuntimeInternal.class},
            new RayRuntimeProxy(obj));
  }

  /**
   * Runs the context check when the method requires it and then forwards the call to the
   * original runtime.
   *
   * @throws Throwable whatever the original runtime threw; the reflection wrapper is removed so
   *     callers see the original exception
   */
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    if (requiresContextCheck(method)) {
      checkIsContextSet();
    }
    try {
      return method.invoke(obj, args);
    } catch (InvocationTargetException e) {
      // Surface the exception thrown by the runtime itself rather than the reflection wrapper.
      Throwable cause = e.getCause();
      throw cause != null ? cause : e;
    }
  }

  /**
   * Whether the context check must pass before the method is forwarded. {@code setAsyncContext}
   * is exempt because it is the call that marks the context as set; {@code shutdown} is exempt as
   * well. Both exemptions match on the method name only.
   */
  private boolean requiresContextCheck(Method method) {
    return isInterfaceMethod(method)
        && !method.getName().equals("shutdown")
        && !method.getName().equals("setAsyncContext");
  }

  /**
   * Whether the method is defined in the {@link RayRuntime} interface.
   */
  private boolean isInterfaceMethod(Method method) {
    try {
      RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
      return true;
    } catch (NoSuchMethodException e) {
      return false;
    }
  }

  /**
   * Check if thread context is set.
   *
   * <p>This method should be invoked at the beginning of most public methods of
   * {@link RayRuntime}, otherwise the native code might crash due to thread local core worker was
   * not set. We check it for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime}
   * because we want to catch the error even if the application runs in
   * {@link RunMode#SINGLE_PROCESS} mode.
   *
   * @throws RayException if the context is not set on the current thread
   */
  private void checkIsContextSet() {
    if (!obj.isContextSet.get()) {
      throw new RayException(
          "`Ray.wrap***` is not called on the current thread."
              + " If you want to use Ray API in your own threads,"
              + " please wrap your executable with `Ray.wrap***`.");
    }
  }
}
@@ -7,11 +7,7 @@ import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;
import org.ray.api.BaseActor;
import org.ray.api.Ray;
import org.ray.api.id.ActorId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.generated.Common.Language;
/**
@@ -20,20 +16,17 @@ import org.ray.runtime.generated.Common.Language;
*/
public abstract class NativeRayActor implements BaseActor, Externalizable {
/**
* Address of core worker.
*/
long nativeCoreWorkerPointer;
/**
* ID of the actor.
*/
byte[] actorId;
NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
private Language language;
NativeRayActor(byte[] actorId, Language language) {
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
this.actorId = actorId;
this.language = language;
}
/**
@@ -42,14 +35,12 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
NativeRayActor() {
}
public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId,
Language language) {
Preconditions.checkState(nativeCoreWorkerPointer != 0);
public static NativeRayActor create(byte[] actorId, Language language) {
switch (language) {
case JAVA:
return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId);
return new NativeRayJavaActor(actorId);
case PYTHON:
return new NativeRayPyActor(nativeCoreWorkerPointer, actorId);
return new NativeRayPyActor(actorId);
default:
throw new IllegalStateException("Unknown actor handle language: " + language);
}
@@ -61,18 +52,19 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
}
public Language getLanguage() {
return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
return language;
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(toBytes());
out.writeObject(nativeSerialize(actorId));
out.writeObject(language);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject());
actorId = nativeDeserialize((byte[]) in.readObject());
language = (Language) in.readObject();
}
/**
@@ -81,7 +73,7 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
* @return the bytes of the actor handle
*/
public byte[] toBytes() {
return nativeSerialize(nativeCoreWorkerPointer, actorId);
return nativeSerialize(actorId);
}
/**
@@ -90,21 +82,10 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
* @return the bytes of an actor handle
*/
public static NativeRayActor fromBytes(byte[] bytes) {
long nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
byte[] actorId = nativeDeserialize(nativeCoreWorkerPointer, bytes);
Language language = Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
byte[] actorId = nativeDeserialize(bytes);
Language language = Language.forNumber(nativeGetLanguage(actorId));
Preconditions.checkNotNull(language);
return create(nativeCoreWorkerPointer, actorId, language);
}
private static long getNativeCoreWorkerPointer() {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
Preconditions.checkState(runtime instanceof RayNativeRuntime);
return ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer();
return create(actorId, language);
}
@Override
@@ -112,13 +93,11 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
// TODO(zhijunfu): do we need to free the ActorHandle in core worker?
}
private static native int nativeGetLanguage(
long nativeCoreWorkerPointer, byte[] actorId);
private static native int nativeGetLanguage(byte[] actorId);
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
long nativeCoreWorkerPointer, byte[] actorId);
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId);
private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId);
private static native byte[] nativeSerialize(byte[] actorId);
private static native byte[] nativeDeserialize(long nativeCoreWorkerPointer, byte[] data);
private static native byte[] nativeDeserialize(byte[] data);
}
@@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
*/
public class NativeRayJavaActor extends NativeRayActor implements RayActor {
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId);
NativeRayJavaActor(byte[] actorId) {
super(actorId, Language.JAVA);
}
/**
@@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
*/
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) {
super(nativeCoreWorkerPointer, actorId);
NativeRayPyActor(byte[] actorId) {
super(actorId, Language.PYTHON);
}
/**
@@ -24,12 +24,12 @@ public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
@Override
public String getModuleName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0);
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(0);
}
@Override
public String getClassName() {
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1);
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(1);
}
@Override
@@ -93,10 +93,6 @@ public class RayConfig {
}
}
/**
* Number of threads that execute tasks.
*/
public final int numberExecThreadsForDevRuntime;
public final int numWorkersPerProcess;
@@ -220,9 +216,6 @@ public class RayConfig {
jobResourcePath = null;
}
// Number of threads that execute tasks.
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
gcsServiceEnabled = System.getenv("RAY_GCS_SERVICE_ENABLED") == null ||
@@ -16,6 +16,7 @@ public class LocalModeWorkerContext implements WorkerContext {
private final JobId jobId;
private ThreadLocal<TaskSpec> currentTask = new ThreadLocal<>();
private final ThreadLocal<UniqueId> currentWorkerId = new ThreadLocal<>();
public LocalModeWorkerContext(JobId jobId) {
this.jobId = jobId;
@@ -23,7 +24,11 @@ public class LocalModeWorkerContext implements WorkerContext {
@Override
public UniqueId getCurrentWorkerId() {
throw new UnsupportedOperationException();
return currentWorkerId.get();
}
/**
 * Sets the worker ID reported by {@code getCurrentWorkerId()}. The value is stored in a
 * {@link ThreadLocal}, so it only applies to the calling thread.
 *
 * @param workerId the worker ID for the calling thread
 */
public void setCurrentWorkerId(UniqueId workerId) {
currentWorkerId.set(workerId);
}
@Override
@@ -12,61 +12,52 @@ import org.ray.runtime.generated.Common.TaskType;
*/
public class NativeWorkerContext implements WorkerContext {
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
private ClassLoader currentClassLoader;
public NativeWorkerContext(long nativeCoreWorkerPointer) {
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
@Override
public UniqueId getCurrentWorkerId() {
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer));
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId());
}
@Override
public JobId getCurrentJobId() {
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer));
return JobId.fromByteBuffer(nativeGetCurrentJobId());
}
@Override
public ActorId getCurrentActorId() {
return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer));
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
}
@Override
public ClassLoader getCurrentClassLoader() {
return currentClassLoader;
return currentClassLoader.get();
}
@Override
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
if (this.currentClassLoader != currentClassLoader) {
this.currentClassLoader = currentClassLoader;
if (this.currentClassLoader.get() != currentClassLoader) {
this.currentClassLoader.set(currentClassLoader);
}
}
@Override
public TaskType getCurrentTaskType() {
return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer));
return TaskType.forNumber(nativeGetCurrentTaskType());
}
@Override
public TaskId getCurrentTaskId() {
return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer));
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
}
private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer);
private static native int nativeGetCurrentTaskType();
private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentTaskId();
private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentJobId();
private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentWorkerId();
private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer);
private static native ByteBuffer nativeGetCurrentActorId();
}
@@ -6,15 +6,15 @@ import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
import org.ray.api.runtimecontext.NodeInfo;
import org.ray.api.runtimecontext.RuntimeContext;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.generated.Common.TaskType;
public class RuntimeContextImpl implements RuntimeContext {
private AbstractRayRuntime runtime;
private RayRuntimeInternal runtime;
public RuntimeContextImpl(AbstractRayRuntime runtime) {
public RuntimeContextImpl(RayRuntimeInternal runtime) {
this.runtime = runtime;
}
@@ -15,56 +15,48 @@ public class NativeObjectStore extends ObjectStore {
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) {
public NativeObjectStore(WorkerContext workerContext) {
super(workerContext);
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
@Override
public ObjectId putRaw(NativeRayObject obj) {
return new ObjectId(nativePut(nativeCoreWorkerPointer, obj));
return new ObjectId(nativePut(obj));
}
@Override
public void putRaw(NativeRayObject obj, ObjectId objectId) {
nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj);
nativePut(objectId.getBytes(), obj);
}
@Override
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs);
return nativeGet(toBinaryList(objectIds), timeoutMs);
}
@Override
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs);
return nativeWait(toBinaryList(objectIds), numObjects, timeoutMs);
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
}
/**
 * Converts a list of object IDs into a list of their raw byte arrays, keeping the order of
 * the input list.
 *
 * @param ids the object IDs to convert.
 * @return one byte array per input ID, as returned by {@code getBytes()}.
 */
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
  List<byte[]> binaryIds = ids.stream()
      .map(id -> id.getBytes())
      .collect(Collectors.toList());
  return binaryIds;
}
private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj);
private static native byte[] nativePut(NativeRayObject obj);
private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId,
NativeRayObject obj);
private static native void nativePut(byte[] objectId, NativeRayObject obj);
private static native List<NativeRayObject> nativeGet(long nativeCoreWorkerPointer,
List<byte[]> ids, long timeoutMs);
private static native List<NativeRayObject> nativeGet(List<byte[]> ids, long timeoutMs);
private static native List<Boolean> nativeWait(long nativeCoreWorkerPointer,
List<byte[]> objectIds, int numObjects, long timeoutMs);
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
long timeoutMs);
private static native void nativeDelete(long nativeCoreWorkerPointer, List<byte[]> objectIds,
boolean localOnly, boolean deleteCreatingTasks);
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
}
@@ -1,9 +1,7 @@
package org.ray.runtime.runner.worker;
import org.ray.api.Ray;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayNativeRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -25,14 +23,7 @@ public class DefaultWorker {
});
Ray.init();
LOGGER.info("Worker started.");
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayNativeRuntime) {
((RayNativeRuntime)runtime).run();
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
((RayMultiWorkerNativeRuntime)runtime).run();
} else {
throw new RuntimeException("Unknown RayRuntime: " + runtime);
}
((RayRuntimeInternal) Ray.internal()).run();
} catch (Exception e) {
LOGGER.error("Failed to start worker.", e);
}
@@ -6,9 +6,7 @@ import java.util.List;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.ObjectId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayMultiWorkerNativeRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.generated.Common.Language;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
@@ -43,11 +41,7 @@ public class ArgumentsBuilder {
} else {
value = ObjectSerializer.serialize(arg);
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
RayRuntime runtime = Ray.internal();
if (runtime instanceof RayMultiWorkerNativeRuntime) {
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
}
id = ((AbstractRayRuntime) runtime).getObjectStore()
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
.putRaw(value);
value = null;
}
@@ -1,17 +1,40 @@
package org.ray.runtime.task;
import org.ray.api.id.ActorId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.task.LocalModeTaskExecutor.LocalActorContext;
/**
* Task executor for local mode.
*/
public class LocalModeTaskExecutor extends TaskExecutor {
public class LocalModeTaskExecutor extends TaskExecutor<LocalActorContext> {
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
/**
 * Actor context used by the local-mode task executor. In addition to the state inherited
 * from {@code TaskExecutor.ActorContext}, it remembers the worker ID it was created with.
 */
static class LocalActorContext extends TaskExecutor.ActorContext {
/**
* The worker ID of the actor.
*/
private final UniqueId workerId;
/**
 * @param workerId the worker ID to record for this actor.
 */
public LocalActorContext(UniqueId workerId) {
this.workerId = workerId;
}
/**
 * @return the worker ID passed to the constructor.
 */
public UniqueId getWorkerId() {
return workerId;
}
}
/**
 * Creates a local-mode task executor.
 *
 * @param runtime the runtime, handed unchanged to the superclass constructor.
 */
public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
}
/**
 * Creates the context for a new actor, capturing the worker ID that the runtime's worker
 * context reports at the moment of the call.
 *
 * @return a new {@code LocalActorContext} holding that worker ID.
 */
@Override
protected LocalActorContext createActorContext() {
return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId());
}
/**
 * No-op: the body is empty, so this executor never saves an actor checkpoint.
 *
 * @param actor the actor object (unused).
 * @param actorId the ID of the actor (unused).
 */
@Override
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
}
@@ -4,16 +4,15 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
@@ -21,9 +20,10 @@ import org.ray.api.BaseActor;
import org.ray.api.id.ActorId;
import org.ray.api.id.ObjectId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.CallOptions;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.actor.LocalModeRayActor;
import org.ray.runtime.context.LocalModeWorkerContext;
import org.ray.runtime.functionmanager.FunctionDescriptor;
@@ -37,6 +37,7 @@ import org.ray.runtime.generated.Common.TaskSpec;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.LocalModeObjectStore;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.task.TaskExecutor.ActorContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -49,7 +50,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
private final Object taskAndObjectLock = new Object();
private final RayDevRuntime runtime;
private final RayRuntimeInternal runtime;
private final TaskExecutor taskExecutor;
private final LocalModeObjectStore objectStore;
/// The thread pool to execute actor tasks.
@@ -58,17 +60,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
/// The thread pool to execute normal tasks.
private final ExecutorService normalTaskExecutorService;
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
private final Object taskExecutorLock = new Object();
private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();
public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore,
int numberThreads) {
private final Map<ActorId, ActorContext> actorContexts = new ConcurrentHashMap<>();
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
// The thread pool that executes normal tasks in parallel.
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
normalTaskExecutorService = Executors.newCachedThreadPool();
// The thread pool that executes actor tasks in parallel.
actorTaskExecutorServices = new HashMap<>();
}
@@ -88,46 +89,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
/**
 * Returns the task executor bound to the calling thread through the
 * {@code currentTaskExecutor} thread-local, or {@code null} if none is bound.
 * <br> NOTE: Cannot be used for multi-threading in worker.
 */
public TaskExecutor getCurrentTaskExecutor() {
return currentTaskExecutor.get();
}
/**
 * Selects the task executor that will run the given task and binds it to the calling thread.
 *
 * <ul>
 * <li>Actor task: the executor previously registered for the task's actor.</li>
 * <li>Actor creation task: a new executor, which is registered for the task's actor.</li>
 * <li>Any other task: an idle executor if one is available, otherwise a new one.</li>
 * </ul>
 *
 * @param task the task about to be executed.
 * @return the selected executor; it is also stored in {@code currentTaskExecutor}.
 */
private TaskExecutor getTaskExecutor(TaskSpec task) {
  final TaskType type = task.getType();
  final TaskExecutor selected;
  synchronized (taskExecutorLock) {
    if (type == TaskType.ACTOR_TASK) {
      selected = actorTaskExecutors.get(getActorId(task));
    } else if (type == TaskType.ACTOR_CREATION_TASK || idleTaskExecutors.size() <= 0) {
      // Both an actor creation and an exhausted idle pool call for a brand-new executor.
      selected = new LocalModeTaskExecutor(runtime);
      if (type == TaskType.ACTOR_CREATION_TASK) {
        actorTaskExecutors.put(getActorId(task), selected);
      }
    } else {
      selected = idleTaskExecutors.pop();
    }
  }
  currentTaskExecutor.set(selected);
  return selected;
}
/**
 * Unbinds the task executor from the calling thread and, for normal tasks only, puts it
 * back at the head of the idle pool so that it can be reused.
 *
 * @param worker the executor that has finished running the task.
 * @param taskSpec the task that was executed.
 */
private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) {
  currentTaskExecutor.remove();
  final boolean reusable = taskSpec.getType() == TaskType.NORMAL_TASK;
  synchronized (taskExecutorLock) {
    if (reusable) {
      idleTaskExecutors.addFirst(worker);
    }
  }
}
private Set<ObjectId> getUnreadyObjects(TaskSpec taskSpec) {
Set<ObjectId> unreadyObjects = new HashSet<>();
// Check whether task arguments are ready.
@@ -257,32 +218,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
final Runnable runnable = () -> {
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
try {
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
List<ObjectId> returnIds = getReturnIds(taskSpec);
for (int i = 0; i < returnIds.size(); i++) {
NativeRayObject putObject;
if (i >= returnObjects.size()) {
// If the task is an actor task or an actor creation task,
// put the dummy object in object store, so those tasks which depends on it
// can be executed.
putObject = new NativeRayObject(new byte[] {1}, null);
} else {
putObject = returnObjects.get(i);
}
objectStore.putRaw(putObject, returnIds.get(i));
}
} finally {
returnTaskExecutor(taskExecutor, taskSpec);
executeTask(taskSpec);
} catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1);
}
};
@@ -313,6 +253,52 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
/**
 * Executes a single task on the calling thread and stores its return objects in the local
 * object store.
 *
 * <p>The method resolves the actor context (actor tasks only), fetches the task arguments,
 * installs the task and worker ID on the worker context, runs the task through
 * {@code taskExecutor}, and finally writes one object for every return ID of the task.
 *
 * @param taskSpec the task to execute.
 */
private void executeTask(TaskSpec taskSpec) {
  ActorContext actorContext = null;
  if (taskSpec.getType() == TaskType.ACTOR_TASK) {
    actorContext = actorContexts.get(getActorId(taskSpec));
    Preconditions.checkNotNull(actorContext);
  }
  List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
      .map(arg -> arg.id != null
          ? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
          : arg.value)
      .collect(Collectors.toList());
  runtime.setIsContextSet(true);
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
  // An actor task runs under the worker ID recorded in its actor context; every other task
  // gets a fresh random worker ID.
  UniqueId workerId = actorContext != null
      ? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
      : UniqueId.randomId();
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
  // Hand the actor context to the executor only after the worker ID is installed.
  // `setActorContext` files the context under the current worker ID, so calling it earlier
  // would use the ID left on this thread by a previous task, or a null key on a fresh thread.
  taskExecutor.setActorContext(actorContext);
  List<NativeRayObject> returnObjects = taskExecutor
      .execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
  if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
    // Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
    // on this actor.
    actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext());
  }
  // Setting this flag to true again is necessary because `taskExecutor.execute()` resets it
  // to false when it finishes, and `runtime.getWorkerContext()` requires it to be true.
  runtime.setIsContextSet(true);
  ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
  runtime.setIsContextSet(false);
  List<ObjectId> returnIds = getReturnIds(taskSpec);
  for (int i = 0; i < returnIds.size(); i++) {
    NativeRayObject putObject;
    if (i >= returnObjects.size()) {
      // If the task is an actor task or an actor creation task, put a dummy object in the
      // object store, so that tasks that depend on it can be executed.
      putObject = new NativeRayObject(new byte[]{1}, null);
    } else {
      putObject = returnObjects.get(i);
    }
    objectStore.putRaw(putObject, returnIds.get(i));
  }
}
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
taskSpec.getFunctionDescriptor();
@@ -8,39 +8,42 @@ import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.Checkpointable.CheckpointContext;
import org.ray.api.id.ActorId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.task.NativeTaskExecutor.NativeActorContext;
/**
* Task executor for cluster mode.
*/
public class NativeTaskExecutor extends TaskExecutor {
public class NativeTaskExecutor extends TaskExecutor<NativeActorContext> {
// TODO(hchen): Use the C++ config.
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
static class NativeActorContext extends TaskExecutor.ActorContext {
/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;
/**
* Number of tasks executed since last actor checkpoint.
*/
private int numTasksSinceLastCheckpoint = 0;
/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;
/**
* IDs of this actor's previous checkpoints.
*/
private List<UniqueId> checkpointIds;
/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;
/**
* Timestamp of the last actor checkpoint.
*/
private long lastCheckpointTimestamp = 0;
}
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
public NativeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
/**
 * Creates the context for a new actor.
 *
 * @return a new {@code NativeActorContext} whose fields hold their default values.
 */
@Override
protected NativeActorContext createActorContext() {
return new NativeActorContext();
}
@Override
@@ -48,15 +51,18 @@ public class NativeTaskExecutor extends TaskExecutor {
if (!(actor instanceof Checkpointable)) {
return;
}
NativeActorContext actorContext = getActorContext();
CheckpointContext checkpointContext = new CheckpointContext(actorId,
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
++actorContext.numTasksSinceLastCheckpoint,
System.currentTimeMillis() - actorContext.lastCheckpointTimestamp);
Checkpointable checkpointable = (Checkpointable) actor;
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
actorContext.numTasksSinceLastCheckpoint = 0;
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint());
List<UniqueId> checkpointIds = actorContext.checkpointIds;
checkpointIds.add(checkpointId);
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
@@ -70,9 +76,10 @@ public class NativeTaskExecutor extends TaskExecutor {
if (!(actor instanceof Checkpointable)) {
return;
}
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
checkpointIds = new ArrayList<>();
NativeActorContext actorContext = getActorContext();
actorContext.numTasksSinceLastCheckpoint = 0;
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
actorContext.checkpointIds = new ArrayList<>();
List<Checkpoint> availableCheckpoints
= runtime.getGcsClient().getCheckpointsForActor(actorId);
if (availableCheckpoints.isEmpty()) {
@@ -90,13 +97,11 @@ public class NativeTaskExecutor extends TaskExecutor {
Preconditions.checkArgument(checkpointValid,
"'loadCheckpoint' must return a checkpoint ID that exists in the "
+ "'availableCheckpoints' list, or null.");
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
nativeNotifyActorResumedFromCheckpoint(checkpointId.getBytes());
}
}
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
private static native byte[] nativePrepareCheckpoint();
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
byte[] checkpointId);
private static native void nativeNotifyActorResumedFromCheckpoint(byte[] checkpointId);
}
@@ -15,30 +15,18 @@ import org.ray.runtime.functionmanager.FunctionDescriptor;
*/
public class NativeTaskSubmitter implements TaskSubmitter {
/**
* The native pointer of core worker.
*/
private final long nativeCoreWorkerPointer;
public NativeTaskSubmitter(long nativeCoreWorkerPointer) {
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
}
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
List<byte[]> returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args,
numReturns, options);
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
@Override
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
options);
return NativeRayActor.create(nativeCoreWorkerPointer, actorId,
functionDescriptor.getLanguage());
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
return NativeRayActor.create(actorId, functionDescriptor.getLanguage());
}
@Override
@@ -46,24 +34,18 @@ public class NativeTaskSubmitter implements TaskSubmitter {
BaseActor actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(actor instanceof NativeRayActor);
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
actor.getId().getBytes(), functionDescriptor, args, numReturns,
options);
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
functionDescriptor, args, numReturns, options);
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}
private static native List<byte[]> nativeSubmitTask(
long nativeCoreWorkerPointer,
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions callOptions);
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
CallOptions callOptions);
private static native byte[] nativeCreateActor(
long nativeCoreWorkerPointer,
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(
long nativeCoreWorkerPointer,
byte[] actorId, FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
}
@@ -1,6 +1,7 @@
package org.ray.runtime.task;
import com.google.common.base.Preconditions;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
@@ -9,58 +10,75 @@ import org.ray.api.id.ActorId;
import org.ray.api.id.JobId;
import org.ray.api.id.TaskId;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.RayRuntimeInternal;
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.generated.Common.TaskType;
import org.ray.runtime.object.NativeRayObject;
import org.ray.runtime.object.ObjectSerializer;
import org.ray.runtime.task.TaskExecutor.ActorContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The task executor, which executes tasks assigned by raylet continuously.
*/
public abstract class TaskExecutor {
public abstract class TaskExecutor<T extends ActorContext> {
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
// A helper map used to look up the corresponding executor for a given worker from JNI.
private static ConcurrentHashMap<UniqueId, TaskExecutor> taskExecutors
= new ConcurrentHashMap<>();
protected final RayRuntimeInternal runtime;
protected final AbstractRayRuntime runtime;
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
/**
* The current actor object, if this worker is an actor, otherwise null.
*/
protected Object currentActor = null;
static class ActorContext {
/**
* The exception that failed the actor creation task, if any.
*/
private Exception actorCreationException = null;
/**
* The current actor object, if this worker is an actor, otherwise null.
*/
Object currentActor = null;
protected TaskExecutor(AbstractRayRuntime runtime) {
this.runtime = runtime;
if (RayConfig.getInstance().runMode == RunMode.CLUSTER) {
taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this);
}
/**
* The exception that failed the actor creation task, if any.
*/
Throwable actorCreationException = null;
}
public static TaskExecutor get(byte[] workerId) {
return taskExecutors.get(new UniqueId(workerId));
TaskExecutor(RayRuntimeInternal runtime) {
this.runtime = runtime;
}
/**
 * Creates a new actor context of the executor-specific type. {@code execute} calls this
 * when it runs an actor creation task.
 *
 * @return the newly created actor context.
 */
protected abstract T createActorContext();
/**
 * Looks up the actor context registered under the worker ID that the runtime's worker
 * context currently reports.
 *
 * @return the registered actor context, or {@code null} if none is registered for that
 *     worker ID.
 */
T getActorContext() {
  UniqueId currentWorkerId = runtime.getWorkerContext().getCurrentWorkerId();
  return actorContextMap.get(currentWorkerId);
}
/**
 * Registers the given actor context under the worker ID that the runtime's worker context
 * currently reports. A {@code null} context is ignored and leaves the map untouched.
 *
 * @param actorContext the actor context to register, or {@code null} to do nothing.
 */
void setActorContext(T actorContext) {
  // ConcurrentHashMap doesn't allow null values, so only non-null contexts are stored.
  if (actorContext != null) {
    UniqueId currentWorkerId = runtime.getWorkerContext().getCurrentWorkerId();
    this.actorContextMap.put(currentWorkerId, actorContext);
  }
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<NativeRayObject> argsBytes) {
runtime.setIsContextSet(true);
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
T actorContext = null;
if (taskType == TaskType.ACTOR_CREATION_TASK) {
actorContext = createActorContext();
setActorContext(actorContext);
} else if (taskType == TaskType.ACTOR_TASK) {
actorContext = getActorContext();
Preconditions.checkNotNull(actorContext);
}
List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
@@ -74,19 +92,26 @@ public abstract class TaskExecutor {
// Get local actor object and arguments.
Object actor = null;
if (taskType == TaskType.ACTOR_TASK) {
if (actorCreationException != null) {
throw actorCreationException;
if (actorContext.actorCreationException != null) {
throw actorContext.actorCreationException;
}
actor = currentActor;
actor = actorContext.currentActor;
}
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
// Execute the task.
Object result;
if (!rayFunction.isConstructor()) {
result = rayFunction.getMethod().invoke(actor, args);
} else {
result = rayFunction.getConstructor().newInstance(args);
try {
if (!rayFunction.isConstructor()) {
result = rayFunction.getMethod().invoke(actor, args);
} else {
result = rayFunction.getConstructor().newInstance(args);
}
} catch (InvocationTargetException e) {
if (e.getCause() != null) {
throw e.getCause();
} else {
throw e;
}
}
// Set result
if (taskType != TaskType.ACTOR_CREATION_TASK) {
@@ -100,10 +125,10 @@ public abstract class TaskExecutor {
} else {
// TODO (kfstorm): handle checkpoint in core worker.
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
currentActor = result;
actorContext.currentActor = result;
}
LOGGER.debug("Finished executing task {}", taskId);
} catch (Exception e) {
} catch (Throwable e) {
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
@@ -113,11 +138,12 @@ public abstract class TaskExecutor {
.serialize(new RayTaskException("Error executing task " + taskId, e)));
}
} else {
actorCreationException = e;
actorContext.actorCreationException = e;
}
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
runtime.getWorkerContext().setCurrentClassLoader(null);
runtime.setIsContextSet(false);
}
return returnObjects;
}
@@ -98,12 +98,6 @@ ray {
}
}
// ----------------------------
// configurations under SINGLE_PROCESS mode
// ----------------------------
dev-runtime {
// Number of threads used to process tasks
execution-parallelism: 10
}
// Whether we enable job manager to submit and manage job.
enable-job-manager: false
}