mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 05:33:18 +08:00
Support multiple core workers in one process (#7623)
This commit is contained in:
@@ -19,7 +19,6 @@ import org.ray.api.function.RayFuncVoid;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.RuntimeContextImpl;
|
||||
@@ -29,6 +28,7 @@ import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
@@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory;
|
||||
/**
|
||||
* Core functionality to implement Ray APIs.
|
||||
*/
|
||||
public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
|
||||
@@ -55,9 +55,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
protected TaskSubmitter taskSubmitter;
|
||||
protected WorkerContext workerContext;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
/**
|
||||
* Whether the required thread context is set on the current thread.
|
||||
*/
|
||||
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
this.functionManager = functionManager;
|
||||
setIsContextSet(rayConfig.workerMode == WorkerType.DRIVER);
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@@ -161,13 +167,54 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
return runnable;
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
isContextSet.set(true);
|
||||
}
|
||||
|
||||
// TODO (kfstorm): Simplify the duplicate code in wrap*** methods.
|
||||
|
||||
@Override
|
||||
public final Runnable wrapRunnable(Runnable runnable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
boolean oldIsContextSet = isContextSet.get();
|
||||
Object oldAsyncContext = null;
|
||||
if (oldIsContextSet) {
|
||||
oldAsyncContext = getAsyncContext();
|
||||
}
|
||||
setAsyncContext(asyncContext);
|
||||
try {
|
||||
runnable.run();
|
||||
} finally {
|
||||
if (oldIsContextSet) {
|
||||
setAsyncContext(oldAsyncContext);
|
||||
} else {
|
||||
setIsContextSet(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
return callable;
|
||||
public final <T> Callable<T> wrapCallable(Callable<T> callable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
boolean oldIsContextSet = isContextSet.get();
|
||||
Object oldAsyncContext = null;
|
||||
if (oldIsContextSet) {
|
||||
oldAsyncContext = getAsyncContext();
|
||||
}
|
||||
setAsyncContext(asyncContext);
|
||||
try {
|
||||
return callable.call();
|
||||
} finally {
|
||||
if (oldIsContextSet) {
|
||||
setAsyncContext(oldAsyncContext);
|
||||
} else {
|
||||
setIsContextSet(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
@@ -209,18 +256,22 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return actor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public WorkerContext getWorkerContext() {
|
||||
return workerContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectStore getObjectStore() {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionManager getFunctionManager() {
|
||||
return functionManager;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayConfig getRayConfig() {
|
||||
return rayConfig;
|
||||
}
|
||||
@@ -229,7 +280,13 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return runtimeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GcsClient getGcsClient() {
|
||||
return gcsClient;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setIsContextSet(boolean isContextSet) {
|
||||
this.isContextSet.set(isContextSet);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,8 +4,6 @@ import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtime.RayRuntimeFactory;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -20,17 +18,13 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
|
||||
public RayRuntime createRayRuntime() {
|
||||
RayConfig rayConfig = RayConfig.getInstance();
|
||||
try {
|
||||
FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
RayRuntime runtime;
|
||||
if (rayConfig.runMode == RunMode.SINGLE_PROCESS) {
|
||||
runtime = new RayDevRuntime(rayConfig, functionManager);
|
||||
} else {
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
runtime = new RayNativeRuntime(rayConfig, functionManager);
|
||||
} else {
|
||||
runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager);
|
||||
}
|
||||
}
|
||||
AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS
|
||||
? new RayDevRuntime(rayConfig)
|
||||
: new RayNativeRuntime(rayConfig);
|
||||
RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1
|
||||
? RayRuntimeProxy.newInstance(innerRuntime)
|
||||
: innerRuntime;
|
||||
runtime.start();
|
||||
return runtime;
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to initialize ray runtime", e);
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
@@ -19,18 +19,31 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
super(rayConfig, functionManager);
|
||||
public RayDevRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
taskExecutor = new LocalModeTaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
|
||||
rayConfig.numberExecThreadsForDevRuntime);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor,
|
||||
(LocalModeObjectStore) objectStore);
|
||||
((LocalModeObjectStore) objectStore).addObjectPutCallback(
|
||||
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
|
||||
objectId -> {
|
||||
if (taskSubmitter != null) {
|
||||
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -60,6 +73,8 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
Preconditions.checkArgument(asyncContext == null);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
|
||||
@@ -1,215 +0,0 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.Callable;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.function.PyActorClass;
|
||||
import org.ray.api.function.PyActorMethod;
|
||||
import org.ray.api.function.PyRemoteFunction;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime}
|
||||
* instances and redirects calls to the correct one based on thread context.
|
||||
*/
|
||||
public class RayMultiWorkerNativeRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class);
|
||||
|
||||
private final FunctionManager functionManager;
|
||||
|
||||
/**
|
||||
* The number of workers per worker process.
|
||||
*/
|
||||
private final int numWorkers;
|
||||
/**
|
||||
* The worker threads.
|
||||
*/
|
||||
private final Thread[] threads;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instances of workers.
|
||||
*/
|
||||
private final RayNativeRuntime[] runtimes;
|
||||
/**
|
||||
* The {@link RayNativeRuntime} instance of current thread.
|
||||
*/
|
||||
private final ThreadLocal<RayNativeRuntime> currentThreadRuntime = new ThreadLocal<>();
|
||||
|
||||
public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
this.functionManager = functionManager;
|
||||
Preconditions.checkState(
|
||||
rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER);
|
||||
Preconditions.checkState(rayConfig.numWorkersPerProcess > 0,
|
||||
"numWorkersPerProcess must be greater than 0.");
|
||||
numWorkers = rayConfig.numWorkersPerProcess;
|
||||
runtimes = new RayNativeRuntime[numWorkers];
|
||||
threads = new Thread[numWorkers];
|
||||
|
||||
LOGGER.info("Starting {} workers.", numWorkers);
|
||||
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
final int workerIndex = i;
|
||||
threads[i] = new Thread(() -> {
|
||||
RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager);
|
||||
runtimes[workerIndex] = runtime;
|
||||
currentThreadRuntime.set(runtime);
|
||||
runtime.run();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public void run() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
threads[i].start();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
runtimes[i].shutdown();
|
||||
}
|
||||
for (int i = 0; i < numWorkers; i++) {
|
||||
try {
|
||||
threads[i].join();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public RayNativeRuntime getCurrentRuntime() {
|
||||
RayNativeRuntime currentRuntime = currentThreadRuntime.get();
|
||||
Preconditions.checkNotNull(currentRuntime,
|
||||
"RayRuntime is not set on current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your `Runnable`s or `Callable`s with"
|
||||
+ " `Ray.wrapRunnable` or `Ray.wrapCallable`.");
|
||||
return currentRuntime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
return getCurrentRuntime().put(obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> T get(ObjectId objectId) {
|
||||
return getCurrentRuntime().get(objectId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return getCurrentRuntime().get(objectIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return getCurrentRuntime().wait(waitList, numReturns, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
getCurrentRuntime().setResource(resourceName, capacity, nodeId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void killActor(BaseActor actor, boolean noReconstruction) {
|
||||
getCurrentRuntime().killActor(actor, noReconstruction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
return getCurrentRuntime().call(func, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args,
|
||||
CallOptions options) {
|
||||
return getCurrentRuntime().call(pyRemoteFunction, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayActor<?> actor, RayFunc func, Object[] args) {
|
||||
return getCurrentRuntime().callActor(actor, func, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) {
|
||||
return getCurrentRuntime().callActor(pyActor, pyActorMethod, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(actorFactoryFunc, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createActor(PyActorClass pyActorClass, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
return getCurrentRuntime().createActor(pyActorClass, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RuntimeContext getRuntimeContext() {
|
||||
return getCurrentRuntime().getRuntimeContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return getCurrentRuntime();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
currentThreadRuntime.set((RayNativeRuntime) asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Runnable wrapRunnable(Runnable runnable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
setAsyncContext(asyncContext);
|
||||
runnable.run();
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public Callable wrapCallable(Callable callable) {
|
||||
Object asyncContext = getAsyncContext();
|
||||
return () -> {
|
||||
setAsyncContext(asyncContext);
|
||||
return callable.call();
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package org.ray.runtime;
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
import org.ray.api.BaseActor;
|
||||
@@ -11,7 +10,6 @@ import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.gcs.GcsClientOptions;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
@@ -34,10 +32,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private RunManager manager = null;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer;
|
||||
|
||||
static {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
@@ -55,14 +49,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
JniUtils.loadLibrary("core_worker_library_java", true);
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath(rayConfig);
|
||||
try {
|
||||
FileUtils.forceMkdir(new File(rayConfig.logDir));
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to create the log directory.", e);
|
||||
}
|
||||
nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters);
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
|
||||
}
|
||||
|
||||
private static void resetLibraryPath(RayConfig rayConfig) {
|
||||
@@ -71,11 +64,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
JniUtils.resetLibraryPath(libraryPath);
|
||||
}
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) {
|
||||
super(rayConfig, functionManager);
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath(rayConfig);
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
if (rayConfig.getRedisAddress() == null) {
|
||||
manager = new RunManager(rayConfig);
|
||||
manager.startRayProcesses(true);
|
||||
@@ -86,21 +80,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
if (rayConfig.getJobId() == JobId.NIL) {
|
||||
rayConfig.setJobId(gcsClient.nextJobId());
|
||||
}
|
||||
int numWorkersPerProcess =
|
||||
rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess;
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(),
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
|
||||
nativeInitialize(rayConfig.workerMode.getNumber(),
|
||||
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
|
||||
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
|
||||
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
new GcsClientOptions(rayConfig), numWorkersPerProcess,
|
||||
rayConfig.logDir, rayConfig.rayletConfigParameters);
|
||||
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
taskExecutor = new NativeTaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext();
|
||||
objectStore = new NativeObjectStore(workerContext);
|
||||
taskSubmitter = new NativeTaskSubmitter();
|
||||
|
||||
LOGGER.info("RayNativeRuntime started with store {}, raylet {}",
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
|
||||
@@ -108,10 +102,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (nativeCoreWorkerPointer != 0) {
|
||||
nativeDestroyCoreWorker(nativeCoreWorkerPointer);
|
||||
nativeCoreWorkerPointer = 0;
|
||||
}
|
||||
nativeShutdown();
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
manager = null;
|
||||
@@ -131,75 +122,56 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
if (nodeId == null) {
|
||||
nodeId = UniqueId.NIL;
|
||||
}
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
nativeSetResource(resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void killActor(BaseActor actor, boolean noReconstruction) {
|
||||
nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction);
|
||||
nativeKillActor(actor.getId().getBytes(), noReconstruction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAsyncContext() {
|
||||
return null;
|
||||
return new AsyncContext(workerContext.getCurrentWorkerId(),
|
||||
workerContext.getCurrentClassLoader());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setAsyncContext(Object asyncContext) {
|
||||
nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes());
|
||||
workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader);
|
||||
super.setAsyncContext(asyncContext);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer);
|
||||
Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER);
|
||||
nativeRunTaskExecutor(taskExecutor);
|
||||
}
|
||||
|
||||
public long getNativeCoreWorkerPointer() {
|
||||
return nativeCoreWorkerPointer;
|
||||
}
|
||||
private static native void nativeInitialize(int workerMode, String ndoeIpAddress,
|
||||
int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
|
||||
byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
|
||||
String logDir, Map<String, String> rayletConfigParameters);
|
||||
|
||||
public TaskExecutor getTaskExecutor() {
|
||||
return taskExecutor;
|
||||
}
|
||||
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
|
||||
|
||||
/**
|
||||
* Register this worker or driver to GCS.
|
||||
*/
|
||||
private void registerWorker() {
|
||||
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
|
||||
if (rayConfig.workerMode == WorkerType.DRIVER) {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("driver_id", workerId);
|
||||
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
|
||||
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
|
||||
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
|
||||
workerInfo.put("name", System.getProperty("user.dir"));
|
||||
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
|
||||
redisClient.hmset("Drivers:" + workerId, workerInfo);
|
||||
} else {
|
||||
workerInfo.put("node_ip_address", rayConfig.nodeIp);
|
||||
workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName);
|
||||
workerInfo.put("raylet_socket", rayConfig.rayletSocketName);
|
||||
//TODO: b"Workers:" + worker.workerId,
|
||||
redisClient.hmset("Workers:" + workerId, workerInfo);
|
||||
private static native void nativeShutdown();
|
||||
|
||||
private static native void nativeSetResource(String resourceName, double capacity, byte[] nodeId);
|
||||
|
||||
private static native void nativeKillActor(byte[] actorId, boolean noReconstruction);
|
||||
|
||||
private static native void nativeSetCoreWorker(byte[] workerId);
|
||||
|
||||
static class AsyncContext {
|
||||
|
||||
public final UniqueId workerId;
|
||||
public final ClassLoader currentClassLoader;
|
||||
|
||||
AsyncContext(UniqueId workerId, ClassLoader currentClassLoader) {
|
||||
this.workerId = workerId;
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
}
|
||||
}
|
||||
|
||||
private static native long nativeInitCoreWorker(int workerMode, String storeSocket,
|
||||
String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId,
|
||||
GcsClientOptions gcsClientOptions);
|
||||
|
||||
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeSetup(String logDir, Map<String, String> rayletConfigParameters);
|
||||
|
||||
private static native void nativeShutdownHook();
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId);
|
||||
|
||||
private static native void nativeKillActor(long nativeCoreWorkerPointer, byte[] actorId,
|
||||
boolean noReconstruction);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
|
||||
/**
|
||||
* This interface is required to make {@link RayRuntimeProxy} work.
* <p>
* {@link RayRuntimeProxy} builds its dynamic proxy over this interface only, so every internal
* accessor that must remain reachable through the proxy has to be declared here.
|
||||
*/
|
||||
public interface RayRuntimeInternal extends RayRuntime {
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
void start();
|
||||
|
||||
/** Returns the worker context held by this runtime. */
WorkerContext getWorkerContext();
|
||||
|
||||
/** Returns the object store held by this runtime. */
ObjectStore getObjectStore();
|
||||
|
||||
/** Returns the function manager held by this runtime. */
FunctionManager getFunctionManager();
|
||||
|
||||
/** Returns the configuration this runtime was created with. */
RayConfig getRayConfig();
|
||||
|
||||
/** Returns the GCS client held by this runtime. */
GcsClient getGcsClient();
|
||||
|
||||
/**
 * Records whether the required thread context is set on the current thread.
 *
 * @param isContextSet the new value of the flag for the calling thread
 */
void setIsContextSet(boolean isContextSet);
|
||||
|
||||
/**
 * Runs the runtime.
 * NOTE(review): of the implementations visible in this change, the native runtime requires
 * worker mode and hands its task executor to native code, while the dev runtime throws
 * {@link UnsupportedOperationException} -- confirm the intended contract for other callers.
 */
void run();
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import java.lang.reflect.InvocationHandler;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.lang.reflect.Method;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
|
||||
/**
|
||||
* Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except
|
||||
* {@code shutdown} and {@code setAsyncContext}, which are forwarded without a check).
|
||||
*/
|
||||
public class RayRuntimeProxy implements InvocationHandler {
|
||||
|
||||
/**
|
||||
* The original runtime.
|
||||
*/
|
||||
private AbstractRayRuntime obj;
|
||||
|
||||
/**
 * Wraps the given runtime. Private: instances are created through
 * {@link #newInstance(AbstractRayRuntime)}.
 *
 * @param obj the runtime that calls are forwarded to
 */
private RayRuntimeProxy(AbstractRayRuntime obj) {
|
||||
this.obj = obj;
|
||||
}
|
||||
|
||||
/**
 * Returns the wrapped runtime that this handler forwards calls to.
 */
public AbstractRayRuntime getRuntimeObject() {
|
||||
return obj;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a new instance of {@link RayRuntimeInternal} with additional context check.
 * <p>
 * The result is a {@link java.lang.reflect.Proxy} that implements only
 * {@link RayRuntimeInternal}; it is not an instance of {@link AbstractRayRuntime}.
 *
 * @param obj the runtime to protect
 * @return a proxy whose calls are routed through {@link #invoke(Object, Method, Object[])}
|
||||
*/
|
||||
static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
|
||||
return (RayRuntimeInternal) java.lang.reflect.Proxy
|
||||
.newProxyInstance(obj.getClass().getClassLoader(), new Class<?>[]{RayRuntimeInternal.class},
|
||||
new RayRuntimeProxy(obj));
|
||||
}
|
||||
|
||||
/**
 * Forwards a proxied call to the wrapped runtime.
 * <p>
 * Methods declared by {@link RayRuntime} first go through the thread-context check, except
 * those named {@code shutdown} or {@code setAsyncContext}. Methods that are not declared by
 * {@link RayRuntime} are forwarded without a check.
 *
 * @throws Throwable the exception thrown by the wrapped method, unwrapped from the reflective
 *     {@link InvocationTargetException} so callers see the original exception type
 */
@Override
|
||||
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
|
||||
// The name-based exemptions apply to every overload of these two methods.
if (isInterfaceMethod(method) && !method.getName().equals("shutdown") && !method.getName()
|
||||
.equals("setAsyncContext")) {
|
||||
checkIsContextSet();
|
||||
}
|
||||
try {
|
||||
return method.invoke(obj, args);
|
||||
} catch (InvocationTargetException e) {
|
||||
// Rethrow the underlying exception rather than the reflection wrapper.
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the method is defined in the {@link RayRuntime} interface.
 * <p>
 * The lookup matches on method name and parameter types only.
|
||||
*/
|
||||
private boolean isInterfaceMethod(Method method) {
|
||||
try {
|
||||
RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
|
||||
return true;
|
||||
} catch (NoSuchMethodException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if thread context is set.
|
||||
* <p>
|
||||
* This method should be invoked at the beginning of most public methods of {@link RayRuntime},
|
||||
* otherwise the native code might crash due to thread local core worker was not set. We check it
|
||||
* for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the
|
||||
* error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode.
 *
 * @throws RayException if the wrapped runtime's {@code isContextSet} flag is false on the
 *     calling thread
|
||||
*/
|
||||
private void checkIsContextSet() {
|
||||
if (!obj.isContextSet.get()) {
|
||||
throw new RayException(
|
||||
"`Ray.wrap***` is not called on the current thread."
|
||||
+ " If you want to use Ray API in your own threads,"
|
||||
+ " please wrap your executable with `Ray.wrap***`.");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,11 +7,7 @@ import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.util.List;
|
||||
import org.ray.api.BaseActor;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
@@ -20,20 +16,17 @@ import org.ray.runtime.generated.Common.Language;
|
||||
*/
|
||||
public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
|
||||
/**
|
||||
* Address of core worker.
|
||||
*/
|
||||
long nativeCoreWorkerPointer;
|
||||
/**
|
||||
* ID of the actor.
|
||||
*/
|
||||
byte[] actorId;
|
||||
|
||||
NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
private Language language;
|
||||
|
||||
NativeRayActor(byte[] actorId, Language language) {
|
||||
Preconditions.checkState(!ActorId.fromBytes(actorId).isNil());
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
this.actorId = actorId;
|
||||
this.language = language;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -42,14 +35,12 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
NativeRayActor() {
|
||||
}
|
||||
|
||||
public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId,
|
||||
Language language) {
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
public static NativeRayActor create(byte[] actorId, Language language) {
|
||||
switch (language) {
|
||||
case JAVA:
|
||||
return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId);
|
||||
return new NativeRayJavaActor(actorId);
|
||||
case PYTHON:
|
||||
return new NativeRayPyActor(nativeCoreWorkerPointer, actorId);
|
||||
return new NativeRayPyActor(actorId);
|
||||
default:
|
||||
throw new IllegalStateException("Unknown actor handle language: " + language);
|
||||
}
|
||||
@@ -61,18 +52,19 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
}
|
||||
|
||||
public Language getLanguage() {
|
||||
return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
|
||||
return language;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(toBytes());
|
||||
out.writeObject(nativeSerialize(actorId));
|
||||
out.writeObject(language);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
|
||||
actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject());
|
||||
actorId = nativeDeserialize((byte[]) in.readObject());
|
||||
language = (Language) in.readObject();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -81,7 +73,7 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
* @return the bytes of the actor handle
|
||||
*/
|
||||
public byte[] toBytes() {
|
||||
return nativeSerialize(nativeCoreWorkerPointer, actorId);
|
||||
return nativeSerialize(actorId);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -90,21 +82,10 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
* @return the actor handle deserialized from the given bytes
|
||||
*/
|
||||
public static NativeRayActor fromBytes(byte[] bytes) {
|
||||
long nativeCoreWorkerPointer = getNativeCoreWorkerPointer();
|
||||
byte[] actorId = nativeDeserialize(nativeCoreWorkerPointer, bytes);
|
||||
Language language = Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId));
|
||||
byte[] actorId = nativeDeserialize(bytes);
|
||||
Language language = Language.forNumber(nativeGetLanguage(actorId));
|
||||
Preconditions.checkNotNull(language);
|
||||
return create(nativeCoreWorkerPointer, actorId, language);
|
||||
}
|
||||
|
||||
private static long getNativeCoreWorkerPointer() {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
Preconditions.checkState(runtime instanceof RayNativeRuntime);
|
||||
|
||||
return ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer();
|
||||
return create(actorId, language);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -112,13 +93,11 @@ public abstract class NativeRayActor implements BaseActor, Externalizable {
|
||||
// TODO(zhijunfu): do we need to free the ActorHandle in core worker?
|
||||
}
|
||||
|
||||
private static native int nativeGetLanguage(
|
||||
long nativeCoreWorkerPointer, byte[] actorId);
|
||||
private static native int nativeGetLanguage(byte[] actorId);
|
||||
|
||||
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
|
||||
long nativeCoreWorkerPointer, byte[] actorId);
|
||||
static native List<String> nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId);
|
||||
|
||||
private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId);
|
||||
private static native byte[] nativeSerialize(byte[] actorId);
|
||||
|
||||
private static native byte[] nativeDeserialize(long nativeCoreWorkerPointer, byte[] data);
|
||||
private static native byte[] nativeDeserialize(byte[] data);
|
||||
}
|
||||
|
||||
@@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
|
||||
*/
|
||||
public class NativeRayJavaActor extends NativeRayActor implements RayActor {
|
||||
|
||||
NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
super(nativeCoreWorkerPointer, actorId);
|
||||
NativeRayJavaActor(byte[] actorId) {
|
||||
super(actorId, Language.JAVA);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language;
|
||||
*/
|
||||
public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
|
||||
|
||||
NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) {
|
||||
super(nativeCoreWorkerPointer, actorId);
|
||||
NativeRayPyActor(byte[] actorId) {
|
||||
super(actorId, Language.PYTHON);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -24,12 +24,12 @@ public class NativeRayPyActor extends NativeRayActor implements RayPyActor {
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -93,10 +93,6 @@ public class RayConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Number of threads that execute tasks.
|
||||
*/
|
||||
public final int numberExecThreadsForDevRuntime;
|
||||
|
||||
public final int numWorkersPerProcess;
|
||||
|
||||
@@ -220,9 +216,6 @@ public class RayConfig {
|
||||
jobResourcePath = null;
|
||||
}
|
||||
|
||||
// Number of threads that execute tasks.
|
||||
numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism");
|
||||
|
||||
numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java");
|
||||
|
||||
gcsServiceEnabled = System.getenv("RAY_GCS_SERVICE_ENABLED") == null ||
|
||||
|
||||
@@ -16,6 +16,7 @@ public class LocalModeWorkerContext implements WorkerContext {
|
||||
|
||||
private final JobId jobId;
|
||||
private ThreadLocal<TaskSpec> currentTask = new ThreadLocal<>();
|
||||
private final ThreadLocal<UniqueId> currentWorkerId = new ThreadLocal<>();
|
||||
|
||||
public LocalModeWorkerContext(JobId jobId) {
|
||||
this.jobId = jobId;
|
||||
@@ -23,7 +24,11 @@ public class LocalModeWorkerContext implements WorkerContext {
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
throw new UnsupportedOperationException();
|
||||
return currentWorkerId.get();
|
||||
}
|
||||
|
||||
public void setCurrentWorkerId(UniqueId workerId) {
|
||||
currentWorkerId.set(workerId);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -12,61 +12,52 @@ import org.ray.runtime.generated.Common.TaskType;
|
||||
*/
|
||||
public class NativeWorkerContext implements WorkerContext {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
public NativeWorkerContext(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer));
|
||||
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public JobId getCurrentJobId() {
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer));
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer));
|
||||
return ActorId.fromByteBuffer(nativeGetCurrentActorId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
return currentClassLoader.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
if (this.currentClassLoader != currentClassLoader) {
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
if (this.currentClassLoader.get() != currentClassLoader) {
|
||||
this.currentClassLoader.set(currentClassLoader);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer));
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType());
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer));
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId());
|
||||
}
|
||||
|
||||
private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer);
|
||||
private static native int nativeGetCurrentTaskType();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentTaskId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentJobId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId();
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer);
|
||||
private static native ByteBuffer nativeGetCurrentActorId();
|
||||
}
|
||||
|
||||
@@ -6,15 +6,15 @@ import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.runtimecontext.NodeInfo;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
public class RuntimeContextImpl implements RuntimeContext {
|
||||
|
||||
private AbstractRayRuntime runtime;
|
||||
private RayRuntimeInternal runtime;
|
||||
|
||||
public RuntimeContextImpl(AbstractRayRuntime runtime) {
|
||||
public RuntimeContextImpl(RayRuntimeInternal runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
|
||||
@@ -15,56 +15,48 @@ public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) {
|
||||
public NativeObjectStore(WorkerContext workerContext) {
|
||||
super(workerContext);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId putRaw(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeCoreWorkerPointer, obj));
|
||||
return new ObjectId(nativePut(obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putRaw(NativeRayObject obj, ObjectId objectId) {
|
||||
nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj);
|
||||
nativePut(objectId.getBytes(), obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs);
|
||||
return nativeGet(toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
return nativeWait(toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj);
|
||||
private static native byte[] nativePut(NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
private static native void nativePut(byte[] objectId, NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeCoreWorkerPointer,
|
||||
List<byte[]> ids, long timeoutMs);
|
||||
private static native List<NativeRayObject> nativeGet(List<byte[]> ids, long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeCoreWorkerPointer,
|
||||
List<byte[]> objectIds, int numObjects, long timeoutMs);
|
||||
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
|
||||
long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeCoreWorkerPointer, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
package org.ray.runtime.runner.worker;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -25,14 +23,7 @@ public class DefaultWorker {
|
||||
});
|
||||
Ray.init();
|
||||
LOGGER.info("Worker started.");
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayNativeRuntime) {
|
||||
((RayNativeRuntime)runtime).run();
|
||||
} else if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
((RayMultiWorkerNativeRuntime)runtime).run();
|
||||
} else {
|
||||
throw new RuntimeException("Unknown RayRuntime: " + runtime);
|
||||
}
|
||||
((RayRuntimeInternal) Ray.internal()).run();
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to start worker.", e);
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayMultiWorkerNativeRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
@@ -43,11 +41,7 @@ public class ArgumentsBuilder {
|
||||
} else {
|
||||
value = ObjectSerializer.serialize(arg);
|
||||
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
RayRuntime runtime = Ray.internal();
|
||||
if (runtime instanceof RayMultiWorkerNativeRuntime) {
|
||||
runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime();
|
||||
}
|
||||
id = ((AbstractRayRuntime) runtime).getObjectStore()
|
||||
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
|
||||
.putRaw(value);
|
||||
value = null;
|
||||
}
|
||||
|
||||
@@ -1,17 +1,40 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.task.LocalModeTaskExecutor.LocalActorContext;
|
||||
|
||||
/**
|
||||
* Task executor for local mode.
|
||||
*/
|
||||
public class LocalModeTaskExecutor extends TaskExecutor {
|
||||
public class LocalModeTaskExecutor extends TaskExecutor<LocalActorContext> {
|
||||
|
||||
public LocalModeTaskExecutor(AbstractRayRuntime runtime) {
|
||||
static class LocalActorContext extends TaskExecutor.ActorContext {
|
||||
|
||||
/**
|
||||
* The worker ID of the actor.
|
||||
*/
|
||||
private final UniqueId workerId;
|
||||
|
||||
public LocalActorContext(UniqueId workerId) {
|
||||
this.workerId = workerId;
|
||||
}
|
||||
|
||||
public UniqueId getWorkerId() {
|
||||
return workerId;
|
||||
}
|
||||
}
|
||||
|
||||
public LocalModeTaskExecutor(RayRuntimeInternal runtime) {
|
||||
super(runtime);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected LocalActorContext createActorContext() {
|
||||
return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
}
|
||||
|
||||
@@ -4,16 +4,15 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -21,9 +20,10 @@ import org.ray.api.BaseActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
@@ -37,6 +37,7 @@ import org.ray.runtime.generated.Common.TaskSpec;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.task.TaskExecutor.ActorContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -49,7 +50,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
|
||||
private final Object taskAndObjectLock = new Object();
|
||||
private final RayDevRuntime runtime;
|
||||
private final RayRuntimeInternal runtime;
|
||||
private final TaskExecutor taskExecutor;
|
||||
private final LocalModeObjectStore objectStore;
|
||||
|
||||
/// The thread pool to execute actor tasks.
|
||||
@@ -58,17 +60,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
/// The thread pool to execute normal tasks.
|
||||
private final ExecutorService normalTaskExecutorService;
|
||||
|
||||
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
|
||||
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
|
||||
private final Object taskExecutorLock = new Object();
|
||||
private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();
|
||||
|
||||
public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore,
|
||||
int numberThreads) {
|
||||
private final Map<ActorId, ActorContext> actorContexts = new ConcurrentHashMap<>();
|
||||
|
||||
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
|
||||
LocalModeObjectStore objectStore) {
|
||||
this.runtime = runtime;
|
||||
this.taskExecutor = taskExecutor;
|
||||
this.objectStore = objectStore;
|
||||
// The thread pool that executes normal tasks in parallel.
|
||||
normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads);
|
||||
normalTaskExecutorService = Executors.newCachedThreadPool();
|
||||
// The thread pool that executes actor tasks in parallel.
|
||||
actorTaskExecutorServices = new HashMap<>();
|
||||
}
|
||||
@@ -88,46 +89,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the worker of current thread. <br> NOTE: Cannot be used for multi-threading in worker.
|
||||
*/
|
||||
public TaskExecutor getCurrentTaskExecutor() {
|
||||
return currentTaskExecutor.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a worker from the worker pool to run the given task.
|
||||
*/
|
||||
private TaskExecutor getTaskExecutor(TaskSpec task) {
|
||||
TaskExecutor taskExecutor;
|
||||
synchronized (taskExecutorLock) {
|
||||
if (task.getType() == TaskType.ACTOR_TASK) {
|
||||
taskExecutor = actorTaskExecutors.get(getActorId(task));
|
||||
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
actorTaskExecutors.put(getActorId(task), taskExecutor);
|
||||
} else if (idleTaskExecutors.size() > 0) {
|
||||
taskExecutor = idleTaskExecutors.pop();
|
||||
} else {
|
||||
taskExecutor = new LocalModeTaskExecutor(runtime);
|
||||
}
|
||||
}
|
||||
currentTaskExecutor.set(taskExecutor);
|
||||
return taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the worker to the worker pool.
|
||||
*/
|
||||
private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) {
|
||||
currentTaskExecutor.remove();
|
||||
synchronized (taskExecutorLock) {
|
||||
if (taskSpec.getType() == TaskType.NORMAL_TASK) {
|
||||
idleTaskExecutors.push(worker);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Set<ObjectId> getUnreadyObjects(TaskSpec taskSpec) {
|
||||
Set<ObjectId> unreadyObjects = new HashSet<>();
|
||||
// Check whether task arguments are ready.
|
||||
@@ -257,32 +218,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
|
||||
|
||||
final Runnable runnable = () -> {
|
||||
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
|
||||
try {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in the object store, so that tasks which depend on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[] {1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
} finally {
|
||||
returnTaskExecutor(taskExecutor, taskSpec);
|
||||
executeTask(taskSpec);
|
||||
} catch (Exception ex) {
|
||||
LOGGER.error("Unexpected exception when executing a task.", ex);
|
||||
System.exit(-1);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -313,6 +253,52 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
}
|
||||
}
|
||||
|
||||
private void executeTask(TaskSpec taskSpec) {
|
||||
ActorContext actorContext = null;
|
||||
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
|
||||
actorContext = actorContexts.get(getActorId(taskSpec));
|
||||
Preconditions.checkNotNull(actorContext);
|
||||
}
|
||||
taskExecutor.setActorContext(actorContext);
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: arg.value)
|
||||
.collect(Collectors.toList());
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
UniqueId workerId = actorContext != null
|
||||
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
|
||||
: UniqueId.randomId();
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
|
||||
// on this actor.
|
||||
actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext());
|
||||
}
|
||||
// Setting this flag to true is necessary because at the end of `taskExecutor.execute()`,
|
||||
// this flag will be set to false. And `runtime.getWorkerContext()` requires it to be
|
||||
// true.
|
||||
runtime.setIsContextSet(true);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
runtime.setIsContextSet(false);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in the object store, so that tasks which depend on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{1}, null);
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
|
||||
org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor =
|
||||
taskSpec.getFunctionDescriptor();
|
||||
|
||||
@@ -8,39 +8,42 @@ import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.task.NativeTaskExecutor.NativeActorContext;
|
||||
|
||||
/**
|
||||
* Task executor for cluster mode.
|
||||
*/
|
||||
public class NativeTaskExecutor extends TaskExecutor {
|
||||
public class NativeTaskExecutor extends TaskExecutor<NativeActorContext> {
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
static class NativeActorContext extends TaskExecutor.ActorContext {
|
||||
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
/**
|
||||
* Number of tasks executed since last actor checkpoint.
|
||||
*/
|
||||
private int numTasksSinceLastCheckpoint = 0;
|
||||
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
/**
|
||||
* IDs of this actor's previous checkpoints.
|
||||
*/
|
||||
private List<UniqueId> checkpointIds;
|
||||
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
/**
|
||||
* Timestamp of the last actor checkpoint.
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
}
|
||||
|
||||
public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) {
|
||||
public NativeTaskExecutor(RayRuntimeInternal runtime) {
|
||||
super(runtime);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected NativeActorContext createActorContext() {
|
||||
return new NativeActorContext();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -48,15 +51,18 @@ public class NativeTaskExecutor extends TaskExecutor {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
NativeActorContext actorContext = getActorContext();
|
||||
CheckpointContext checkpointContext = new CheckpointContext(actorId,
|
||||
++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp);
|
||||
++actorContext.numTasksSinceLastCheckpoint,
|
||||
System.currentTimeMillis() - actorContext.lastCheckpointTimestamp);
|
||||
Checkpointable checkpointable = (Checkpointable) actor;
|
||||
if (!checkpointable.shouldCheckpoint(checkpointContext)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer));
|
||||
actorContext.numTasksSinceLastCheckpoint = 0;
|
||||
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint());
|
||||
List<UniqueId> checkpointIds = actorContext.checkpointIds;
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
@@ -70,9 +76,10 @@ public class NativeTaskExecutor extends TaskExecutor {
|
||||
if (!(actor instanceof Checkpointable)) {
|
||||
return;
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
NativeActorContext actorContext = getActorContext();
|
||||
actorContext.numTasksSinceLastCheckpoint = 0;
|
||||
actorContext.lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
actorContext.checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= runtime.getGcsClient().getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
@@ -90,13 +97,11 @@ public class NativeTaskExecutor extends TaskExecutor {
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes());
|
||||
nativeNotifyActorResumedFromCheckpoint(checkpointId.getBytes());
|
||||
}
|
||||
}
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer);
|
||||
private static native byte[] nativePrepareCheckpoint();
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer,
|
||||
byte[] checkpointId);
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(byte[] checkpointId);
|
||||
}
|
||||
|
||||
@@ -15,30 +15,18 @@ import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
*/
|
||||
public class NativeTaskSubmitter implements TaskSubmitter {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeTaskSubmitter(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options) {
|
||||
List<byte[]> returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
numReturns, options);
|
||||
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public BaseActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions options) {
|
||||
byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
options);
|
||||
return NativeRayActor.create(nativeCoreWorkerPointer, actorId,
|
||||
functionDescriptor.getLanguage());
|
||||
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
|
||||
return NativeRayActor.create(actorId, functionDescriptor.getLanguage());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -46,24 +34,18 @@ public class NativeTaskSubmitter implements TaskSubmitter {
|
||||
BaseActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(actor instanceof NativeRayActor);
|
||||
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
|
||||
actor.getId().getBytes(), functionDescriptor, args, numReturns,
|
||||
options);
|
||||
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
|
||||
functionDescriptor, args, numReturns, options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native List<byte[]> nativeSubmitTask(
|
||||
long nativeCoreWorkerPointer,
|
||||
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions callOptions);
|
||||
|
||||
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
|
||||
|
||||
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
|
||||
CallOptions callOptions);
|
||||
|
||||
private static native byte[] nativeCreateActor(
|
||||
long nativeCoreWorkerPointer,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions actorCreationOptions);
|
||||
|
||||
private static native List<byte[]> nativeSubmitActorTask(
|
||||
long nativeCoreWorkerPointer,
|
||||
byte[] actorId, FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions callOptions);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.lang.reflect.InvocationTargetException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
@@ -9,58 +10,75 @@ import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.RayRuntimeInternal;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectSerializer;
|
||||
import org.ray.runtime.task.TaskExecutor.ActorContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The task executor, which executes tasks assigned by raylet continuously.
|
||||
*/
|
||||
public abstract class TaskExecutor {
|
||||
public abstract class TaskExecutor<T extends ActorContext> {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// A helper map to help we get the corresponding executor for the given worker in JNI.
|
||||
private static ConcurrentHashMap<UniqueId, TaskExecutor> taskExecutors
|
||||
= new ConcurrentHashMap<>();
|
||||
protected final RayRuntimeInternal runtime;
|
||||
|
||||
protected final AbstractRayRuntime runtime;
|
||||
private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();
|
||||
|
||||
/**
|
||||
* The current actor object, if this worker is an actor, otherwise null.
|
||||
*/
|
||||
protected Object currentActor = null;
|
||||
static class ActorContext {
|
||||
|
||||
/**
|
||||
* The exception that failed the actor creation task, if any.
|
||||
*/
|
||||
private Exception actorCreationException = null;
|
||||
/**
|
||||
* The current actor object, if this worker is an actor, otherwise null.
|
||||
*/
|
||||
Object currentActor = null;
|
||||
|
||||
protected TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
if (RayConfig.getInstance().runMode == RunMode.CLUSTER) {
|
||||
taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this);
|
||||
}
|
||||
/**
|
||||
* The exception that failed the actor creation task, if any.
|
||||
*/
|
||||
Throwable actorCreationException = null;
|
||||
}
|
||||
|
||||
public static TaskExecutor get(byte[] workerId) {
|
||||
return taskExecutors.get(new UniqueId(workerId));
|
||||
TaskExecutor(RayRuntimeInternal runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
protected abstract T createActorContext();
|
||||
|
||||
T getActorContext() {
|
||||
return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId());
|
||||
}
|
||||
|
||||
void setActorContext(T actorContext) {
|
||||
if (actorContext == null) {
|
||||
// ConcurrentHashMap doesn't allow null values. So just return here.
|
||||
return;
|
||||
}
|
||||
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
|
||||
}
|
||||
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<NativeRayObject> argsBytes) {
|
||||
runtime.setIsContextSet(true);
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
|
||||
T actorContext = null;
|
||||
if (taskType == TaskType.ACTOR_CREATION_TASK) {
|
||||
actorContext = createActorContext();
|
||||
setActorContext(actorContext);
|
||||
} else if (taskType == TaskType.ACTOR_TASK) {
|
||||
actorContext = getActorContext();
|
||||
Preconditions.checkNotNull(actorContext);
|
||||
}
|
||||
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
|
||||
@@ -74,19 +92,26 @@ public abstract class TaskExecutor {
|
||||
// Get local actor object and arguments.
|
||||
Object actor = null;
|
||||
if (taskType == TaskType.ACTOR_TASK) {
|
||||
if (actorCreationException != null) {
|
||||
throw actorCreationException;
|
||||
if (actorContext.actorCreationException != null) {
|
||||
throw actorContext.actorCreationException;
|
||||
}
|
||||
actor = currentActor;
|
||||
|
||||
actor = actorContext.currentActor;
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader);
|
||||
// Execute the task.
|
||||
Object result;
|
||||
if (!rayFunction.isConstructor()) {
|
||||
result = rayFunction.getMethod().invoke(actor, args);
|
||||
} else {
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
try {
|
||||
if (!rayFunction.isConstructor()) {
|
||||
result = rayFunction.getMethod().invoke(actor, args);
|
||||
} else {
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
}
|
||||
} catch (InvocationTargetException e) {
|
||||
if (e.getCause() != null) {
|
||||
throw e.getCause();
|
||||
} else {
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
// Set result
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
@@ -100,10 +125,10 @@ public abstract class TaskExecutor {
|
||||
} else {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
|
||||
currentActor = result;
|
||||
actorContext.currentActor = result;
|
||||
}
|
||||
LOGGER.debug("Finished executing task {}", taskId);
|
||||
} catch (Exception e) {
|
||||
} catch (Throwable e) {
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
|
||||
@@ -113,11 +138,12 @@ public abstract class TaskExecutor {
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
}
|
||||
} else {
|
||||
actorCreationException = e;
|
||||
actorContext.actorCreationException = e;
|
||||
}
|
||||
} finally {
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(null);
|
||||
runtime.setIsContextSet(false);
|
||||
}
|
||||
return returnObjects;
|
||||
}
|
||||
|
||||
@@ -98,12 +98,6 @@ ray {
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------
|
||||
// configurations under SINGLE_PROCESS mode
|
||||
// ----------------------------
|
||||
dev-runtime {
|
||||
// Number of threads that you process tasks
|
||||
execution-parallelism: 10
|
||||
}
|
||||
|
||||
// Whether we enable job manager to submit and manage job.
|
||||
enable-job-manager: false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user