diff --git a/BUILD.bazel b/BUILD.bazel index fe64d9a5b..d26295b79 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -74,7 +74,10 @@ python_grpc_compile( proto_library( name = "gcs_service_proto", srcs = ["src/ray/protobuf/gcs_service.proto"], - deps = [":gcs_proto"], + deps = [ + ":common_proto", + ":gcs_proto", + ], ) cc_proto_library( diff --git a/cpp/src/ray/runtime/local_mode_ray_runtime.cc b/cpp/src/ray/runtime/local_mode_ray_runtime.cc index 898f3cad4..113f7b148 100644 --- a/cpp/src/ray/runtime/local_mode_ray_runtime.cc +++ b/cpp/src/ray/runtime/local_mode_ray_runtime.cc @@ -12,8 +12,8 @@ namespace api { LocalModeRayRuntime::LocalModeRayRuntime(std::shared_ptr config) { config_ = config; - worker_ = - std::unique_ptr(new WorkerContext(WorkerType::DRIVER, JobID::Nil())); + worker_ = std::unique_ptr(new WorkerContext( + WorkerType::DRIVER, ComputeDriverIdFromJob(JobID::Nil()), JobID::Nil())); object_store_ = std::unique_ptr(new LocalModeObjectStore(*this)); task_submitter_ = std::unique_ptr(new LocalModeTaskSubmitter(*this)); } diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index f66be6c73..cf673b13a 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -44,7 +44,7 @@ public final class Ray extends RayCall { /** * Shutdown Ray runtime. */ - public static void shutdown() { + public static synchronized void shutdown() { if (runtime != null) { runtime.shutdown(); runtime = null; @@ -137,6 +137,11 @@ public final class Ray extends RayCall { runtime.setAsyncContext(asyncContext); } + // TODO (kfstorm): add the `rollbackAsyncContext` API to allow rolling back the async context of + // the current thread to the one before `setAsyncContext` is called. + + // TODO (kfstorm): unify the `wrap*` methods. + /** * If users want to use Ray API in their own threads, they should wrap their {@link Runnable} * objects with this method. 
@@ -155,7 +160,7 @@ public final class Ray extends RayCall { * @param callable The callable to wrap. * @return The wrapped callable. */ - public static Callable wrapCallable(Callable callable) { + public static Callable wrapCallable(Callable callable) { return runtime.wrapCallable(callable); } diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 308f135b9..9f8552e74 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -159,6 +159,7 @@ public interface RayRuntime { /** * Wrap a {@link Runnable} with necessary context capture. + * * @param runnable The runnable to wrap. * @return The wrapped runnable. */ @@ -166,8 +167,9 @@ public interface RayRuntime { /** * Wrap a {@link Callable} with necessary context capture. + * * @param callable The callable to wrap. * @return The wrapped callable. */ - Callable wrapCallable(Callable callable); + Callable wrapCallable(Callable callable); } diff --git a/java/generate_jni_header_files.sh b/java/generate_jni_header_files.sh new file mode 100755 index 000000000..7e8c73e1b --- /dev/null +++ b/java/generate_jni_header_files.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +set -e +set -x + +cd "$(dirname "$0")" + +(cd .. && bazel build //java:all_tests_deploy.jar) + +function generate_one() +{ + file=${1//./_}.h + javah -classpath ../bazel-bin/java/all_tests_deploy.jar $1 + clang-format -i $file + + cat <<EOF > ../src/ray/core_worker/lib/java/$file +// Copyright 2017 The Ray Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +EOF + cat $file >> ../src/ray/core_worker/lib/java/$file + rm -f $file +} + +generate_one org.ray.runtime.RayNativeRuntime +generate_one org.ray.runtime.task.NativeTaskSubmitter +generate_one org.ray.runtime.context.NativeWorkerContext +generate_one org.ray.runtime.actor.NativeRayActor +generate_one org.ray.runtime.object.NativeObjectStore +generate_one org.ray.runtime.task.NativeTaskExecutor + +# Remove empty files +rm -f org_ray_runtime_RayNativeRuntime_AsyncContext.h +rm -f org_ray_runtime_task_NativeTaskExecutor_NativeActorContext.h \ No newline at end of file diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index ffdf7992d..44c8bd6dd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -19,7 +19,6 @@ import org.ray.api.function.RayFuncVoid; import org.ray.api.id.ObjectId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; -import org.ray.api.runtime.RayRuntime; import org.ray.api.runtimecontext.RuntimeContext; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.RuntimeContextImpl; @@ -29,6 +28,7 @@ import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.functionmanager.PyFunctionDescriptor; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.generated.Common.Language; +import org.ray.runtime.generated.Common.WorkerType; import 
org.ray.runtime.object.ObjectStore; import org.ray.runtime.object.RayObjectImpl; import org.ray.runtime.task.ArgumentsBuilder; @@ -41,7 +41,7 @@ import org.slf4j.LoggerFactory; /** * Core functionality to implement Ray APIs. */ -public abstract class AbstractRayRuntime implements RayRuntime { +public abstract class AbstractRayRuntime implements RayRuntimeInternal { private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class); public static final String PYTHON_INIT_METHOD_NAME = "__init__"; @@ -55,9 +55,15 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected TaskSubmitter taskSubmitter; protected WorkerContext workerContext; - public AbstractRayRuntime(RayConfig rayConfig, FunctionManager functionManager) { + /** + * Whether the required thread context is set on the current thread. + */ + final ThreadLocal isContextSet = ThreadLocal.withInitial(() -> false); + + public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; - this.functionManager = functionManager; + setIsContextSet(rayConfig.workerMode == WorkerType.DRIVER); + functionManager = new FunctionManager(rayConfig.jobResourcePath); runtimeContext = new RuntimeContextImpl(this); } @@ -161,13 +167,54 @@ public abstract class AbstractRayRuntime implements RayRuntime { } @Override - public Runnable wrapRunnable(Runnable runnable) { - return runnable; + public void setAsyncContext(Object asyncContext) { + isContextSet.set(true); + } + + // TODO (kfstorm): Simplify the duplicate code in wrap*** methods. 
+ + @Override + public final Runnable wrapRunnable(Runnable runnable) { + Object asyncContext = getAsyncContext(); + return () -> { + boolean oldIsContextSet = isContextSet.get(); + Object oldAsyncContext = null; + if (oldIsContextSet) { + oldAsyncContext = getAsyncContext(); + } + setAsyncContext(asyncContext); + try { + runnable.run(); + } finally { + if (oldIsContextSet) { + setAsyncContext(oldAsyncContext); + } else { + setIsContextSet(false); + } + } + }; } @Override - public Callable wrapCallable(Callable callable) { - return callable; + public final Callable wrapCallable(Callable callable) { + Object asyncContext = getAsyncContext(); + return () -> { + boolean oldIsContextSet = isContextSet.get(); + Object oldAsyncContext = null; + if (oldIsContextSet) { + oldAsyncContext = getAsyncContext(); + } + setAsyncContext(asyncContext); + try { + return callable.call(); + } finally { + if (oldIsContextSet) { + setAsyncContext(oldAsyncContext); + } else { + setIsContextSet(false); + } + } + }; } private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, @@ -209,18 +256,22 @@ public abstract class AbstractRayRuntime implements RayRuntime { return actor; } + @Override public WorkerContext getWorkerContext() { return workerContext; } + @Override public ObjectStore getObjectStore() { return objectStore; } + @Override public FunctionManager getFunctionManager() { return functionManager; } + @Override public RayConfig getRayConfig() { return rayConfig; } @@ -229,7 +280,13 @@ public abstract class AbstractRayRuntime implements RayRuntime { return runtimeContext; } + @Override public GcsClient getGcsClient() { return gcsClient; } + + @Override + public void setIsContextSet(boolean isContextSet) { + this.isContextSet.set(isContextSet); + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java b/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java index 69448ecb7..02e2c93be 100644 --- 
a/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java +++ b/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java @@ -4,8 +4,6 @@ import org.ray.api.runtime.RayRuntime; import org.ray.api.runtime.RayRuntimeFactory; import org.ray.runtime.config.RayConfig; import org.ray.runtime.config.RunMode; -import org.ray.runtime.functionmanager.FunctionManager; -import org.ray.runtime.generated.Common.WorkerType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,17 +18,13 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory { public RayRuntime createRayRuntime() { RayConfig rayConfig = RayConfig.getInstance(); try { - FunctionManager functionManager = new FunctionManager(rayConfig.jobResourcePath); - RayRuntime runtime; - if (rayConfig.runMode == RunMode.SINGLE_PROCESS) { - runtime = new RayDevRuntime(rayConfig, functionManager); - } else { - if (rayConfig.workerMode == WorkerType.DRIVER) { - runtime = new RayNativeRuntime(rayConfig, functionManager); - } else { - runtime = new RayMultiWorkerNativeRuntime(rayConfig, functionManager); - } - } + AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS + ? new RayDevRuntime(rayConfig) + : new RayNativeRuntime(rayConfig); + RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1 + ? 
RayRuntimeProxy.newInstance(innerRuntime) + : innerRuntime; + runtime.start(); return runtime; } catch (Exception e) { LOGGER.error("Failed to initialize ray runtime", e); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index 8f9d74e5b..a46cafda7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -1,12 +1,12 @@ package org.ray.runtime; +import com.google.common.base.Preconditions; import java.util.concurrent.atomic.AtomicInteger; import org.ray.api.BaseActor; import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.LocalModeWorkerContext; -import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.object.LocalModeObjectStore; import org.ray.runtime.task.LocalModeTaskExecutor; import org.ray.runtime.task.LocalModeTaskSubmitter; @@ -19,18 +19,31 @@ public class RayDevRuntime extends AbstractRayRuntime { private AtomicInteger jobCounter = new AtomicInteger(0); - public RayDevRuntime(RayConfig rayConfig, FunctionManager functionManager) { - super(rayConfig, functionManager); + public RayDevRuntime(RayConfig rayConfig) { + super(rayConfig); + } + + @Override + public void start() { if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } taskExecutor = new LocalModeTaskExecutor(this); workerContext = new LocalModeWorkerContext(rayConfig.getJobId()); objectStore = new LocalModeObjectStore(workerContext); - taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore, - rayConfig.numberExecThreadsForDevRuntime); + taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor, + (LocalModeObjectStore) objectStore); ((LocalModeObjectStore) objectStore).addObjectPutCallback( - objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId)); + 
objectId -> { + if (taskSubmitter != null) { + ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId); + } + }); + } + + @Override + public void run() { + throw new UnsupportedOperationException(); } @Override @@ -60,6 +73,8 @@ public class RayDevRuntime extends AbstractRayRuntime { @Override public void setAsyncContext(Object asyncContext) { + Preconditions.checkArgument(asyncContext == null); + super.setAsyncContext(asyncContext); } private JobId nextJobId() { diff --git a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java deleted file mode 100644 index 045da3a31..000000000 --- a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java +++ /dev/null @@ -1,215 +0,0 @@ -package org.ray.runtime; - -import com.google.common.base.Preconditions; -import java.util.List; -import java.util.concurrent.Callable; -import org.ray.api.BaseActor; -import org.ray.api.RayActor; -import org.ray.api.RayObject; -import org.ray.api.RayPyActor; -import org.ray.api.WaitResult; -import org.ray.api.function.PyActorClass; -import org.ray.api.function.PyActorMethod; -import org.ray.api.function.PyRemoteFunction; -import org.ray.api.function.RayFunc; -import org.ray.api.id.ObjectId; -import org.ray.api.id.UniqueId; -import org.ray.api.options.ActorCreationOptions; -import org.ray.api.options.CallOptions; -import org.ray.api.runtime.RayRuntime; -import org.ray.api.runtimecontext.RuntimeContext; -import org.ray.runtime.config.RayConfig; -import org.ray.runtime.config.RunMode; -import org.ray.runtime.functionmanager.FunctionManager; -import org.ray.runtime.generated.Common.WorkerType; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime} - * instances and redirect calls to the correct one based on thread context. 
- */ -public class RayMultiWorkerNativeRuntime implements RayRuntime { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class); - - private final FunctionManager functionManager; - - /** - * The number of workers per worker process. - */ - private final int numWorkers; - /** - * The worker threads. - */ - private final Thread[] threads; - /** - * The {@link RayNativeRuntime} instances of workers. - */ - private final RayNativeRuntime[] runtimes; - /** - * The {@link RayNativeRuntime} instance of current thread. - */ - private final ThreadLocal currentThreadRuntime = new ThreadLocal<>(); - - public RayMultiWorkerNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) { - this.functionManager = functionManager; - Preconditions.checkState( - rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER); - Preconditions.checkState(rayConfig.numWorkersPerProcess > 0, - "numWorkersPerProcess must be greater than 0."); - numWorkers = rayConfig.numWorkersPerProcess; - runtimes = new RayNativeRuntime[numWorkers]; - threads = new Thread[numWorkers]; - - LOGGER.info("Starting {} workers.", numWorkers); - - for (int i = 0; i < numWorkers; i++) { - final int workerIndex = i; - threads[i] = new Thread(() -> { - RayNativeRuntime runtime = new RayNativeRuntime(rayConfig, functionManager); - runtimes[workerIndex] = runtime; - currentThreadRuntime.set(runtime); - runtime.run(); - }); - } - } - - public void run() { - for (int i = 0; i < numWorkers; i++) { - threads[i].start(); - } - for (int i = 0; i < numWorkers; i++) { - try { - threads[i].join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - - @Override - public void shutdown() { - for (int i = 0; i < numWorkers; i++) { - runtimes[i].shutdown(); - } - for (int i = 0; i < numWorkers; i++) { - try { - threads[i].join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - } - - public 
RayNativeRuntime getCurrentRuntime() { - RayNativeRuntime currentRuntime = currentThreadRuntime.get(); - Preconditions.checkNotNull(currentRuntime, - "RayRuntime is not set on current thread." - + " If you want to use Ray API in your own threads," - + " please wrap your `Runnable`s or `Callable`s with" - + " `Ray.wrapRunnable` or `Ray.wrapCallable`."); - return currentRuntime; - } - - @Override - public RayObject put(T obj) { - return getCurrentRuntime().put(obj); - } - - @Override - public T get(ObjectId objectId) { - return getCurrentRuntime().get(objectId); - } - - @Override - public List get(List objectIds) { - return getCurrentRuntime().get(objectIds); - } - - @Override - public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { - return getCurrentRuntime().wait(waitList, numReturns, timeoutMs); - } - - @Override - public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks); - } - - @Override - public void setResource(String resourceName, double capacity, UniqueId nodeId) { - getCurrentRuntime().setResource(resourceName, capacity, nodeId); - } - - @Override - public void killActor(BaseActor actor, boolean noReconstruction) { - getCurrentRuntime().killActor(actor, noReconstruction); - } - - @Override - public RayObject call(RayFunc func, Object[] args, CallOptions options) { - return getCurrentRuntime().call(func, args, options); - } - - @Override - public RayObject call(PyRemoteFunction pyRemoteFunction, Object[] args, - CallOptions options) { - return getCurrentRuntime().call(pyRemoteFunction, args, options); - } - - @Override - public RayObject callActor(RayActor actor, RayFunc func, Object[] args) { - return getCurrentRuntime().callActor(actor, func, args); - } - - @Override - public RayObject callActor(RayPyActor pyActor, PyActorMethod pyActorMethod, Object[] args) { - return getCurrentRuntime().callActor(pyActor, pyActorMethod, args); - } - - 
@Override - public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, - ActorCreationOptions options) { - return getCurrentRuntime().createActor(actorFactoryFunc, args, options); - } - - @Override - public RayPyActor createActor(PyActorClass pyActorClass, Object[] args, - ActorCreationOptions options) { - return getCurrentRuntime().createActor(pyActorClass, args, options); - } - - @Override - public RuntimeContext getRuntimeContext() { - return getCurrentRuntime().getRuntimeContext(); - } - - @Override - public Object getAsyncContext() { - return getCurrentRuntime(); - } - - @Override - public void setAsyncContext(Object asyncContext) { - currentThreadRuntime.set((RayNativeRuntime) asyncContext); - } - - @Override - public Runnable wrapRunnable(Runnable runnable) { - Object asyncContext = getAsyncContext(); - return () -> { - setAsyncContext(asyncContext); - runnable.run(); - }; - } - - @Override - public Callable wrapCallable(Callable callable) { - Object asyncContext = getAsyncContext(); - return () -> { - setAsyncContext(asyncContext); - return callable.call(); - }; - } -} diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 6448b824e..67fef8906 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -3,7 +3,6 @@ package org.ray.runtime; import com.google.common.base.Preconditions; import java.io.File; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import org.apache.commons.io.FileUtils; import org.ray.api.BaseActor; @@ -11,7 +10,6 @@ import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.NativeWorkerContext; -import org.ray.runtime.functionmanager.FunctionManager; import org.ray.runtime.gcs.GcsClient; import org.ray.runtime.gcs.GcsClientOptions; import 
org.ray.runtime.gcs.RedisClient; @@ -34,10 +32,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private RunManager manager = null; - /** - * The native pointer of core worker. - */ - private long nativeCoreWorkerPointer; static { LOGGER.debug("Loading native libraries."); @@ -55,14 +49,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime { JniUtils.loadLibrary("core_worker_library_java", true); LOGGER.debug("Native libraries loaded."); + // Reset library path at runtime. resetLibraryPath(rayConfig); try { FileUtils.forceMkdir(new File(rayConfig.logDir)); } catch (IOException e) { throw new RuntimeException("Failed to create the log directory.", e); } - nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters); - Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook)); } private static void resetLibraryPath(RayConfig rayConfig) { @@ -71,11 +64,12 @@ public final class RayNativeRuntime extends AbstractRayRuntime { JniUtils.resetLibraryPath(libraryPath); } - public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) { - super(rayConfig, functionManager); - // Reset library path at runtime. - resetLibraryPath(rayConfig); + public RayNativeRuntime(RayConfig rayConfig) { + super(rayConfig); + } + @Override + public void start() { if (rayConfig.getRedisAddress() == null) { manager = new RunManager(rayConfig); manager.startRayProcesses(true); @@ -86,21 +80,21 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (rayConfig.getJobId() == JobId.NIL) { rayConfig.setJobId(gcsClient.nextJobId()); } + int numWorkersPerProcess = + rayConfig.workerMode == WorkerType.DRIVER ? 1 : rayConfig.numWorkersPerProcess; // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. 
- nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(), - rayConfig.objectStoreSocketName, rayConfig.rayletSocketName, + nativeInitialize(rayConfig.workerMode.getNumber(), rayConfig.nodeIp, rayConfig.getNodeManagerPort(), + rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "", + rayConfig.objectStoreSocketName, rayConfig.rayletSocketName, (rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(), - new GcsClientOptions(rayConfig)); - Preconditions.checkState(nativeCoreWorkerPointer != 0); + new GcsClientOptions(rayConfig), numWorkersPerProcess, + rayConfig.logDir, rayConfig.rayletConfigParameters); - workerContext = new NativeWorkerContext(nativeCoreWorkerPointer); - taskExecutor = new NativeTaskExecutor(nativeCoreWorkerPointer, this); - objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer); - taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer); - - // register - registerWorker(); + taskExecutor = new NativeTaskExecutor(this); + workerContext = new NativeWorkerContext(); + objectStore = new NativeObjectStore(workerContext); + taskSubmitter = new NativeTaskSubmitter(); LOGGER.info("RayNativeRuntime started with store {}, raylet {}", rayConfig.objectStoreSocketName, rayConfig.rayletSocketName); @@ -108,10 +102,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { @Override public void shutdown() { - if (nativeCoreWorkerPointer != 0) { - nativeDestroyCoreWorker(nativeCoreWorkerPointer); - nativeCoreWorkerPointer = 0; - } + nativeShutdown(); if (null != manager) { manager.cleanup(); manager = null; @@ -131,75 +122,56 @@ public final class RayNativeRuntime extends AbstractRayRuntime { if (nodeId == null) { nodeId = UniqueId.NIL; } - nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes()); + nativeSetResource(resourceName, capacity, nodeId.getBytes()); } @Override public void killActor(BaseActor actor, 
boolean noReconstruction) { - nativeKillActor(nativeCoreWorkerPointer, actor.getId().getBytes(), noReconstruction); + nativeKillActor(actor.getId().getBytes(), noReconstruction); } @Override public Object getAsyncContext() { - return null; + return new AsyncContext(workerContext.getCurrentWorkerId(), + workerContext.getCurrentClassLoader()); } @Override public void setAsyncContext(Object asyncContext) { + nativeSetCoreWorker(((AsyncContext) asyncContext).workerId.getBytes()); + workerContext.setCurrentClassLoader(((AsyncContext) asyncContext).currentClassLoader); + super.setAsyncContext(asyncContext); } + @Override public void run() { - nativeRunTaskExecutor(nativeCoreWorkerPointer); + Preconditions.checkState(rayConfig.workerMode == WorkerType.WORKER); + nativeRunTaskExecutor(taskExecutor); } - public long getNativeCoreWorkerPointer() { - return nativeCoreWorkerPointer; - } + private static native void nativeInitialize(int workerMode, String ndoeIpAddress, + int nodeManagerPort, String driverName, String storeSocket, String rayletSocket, + byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess, + String logDir, Map rayletConfigParameters); - public TaskExecutor getTaskExecutor() { - return taskExecutor; - } + private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor); - /** - * Register this worker or driver to GCS. 
- */ - private void registerWorker() { - RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); - Map workerInfo = new HashMap<>(); - String workerId = new String(workerContext.getCurrentWorkerId().getBytes()); - if (rayConfig.workerMode == WorkerType.DRIVER) { - workerInfo.put("node_ip_address", rayConfig.nodeIp); - workerInfo.put("driver_id", workerId); - workerInfo.put("start_time", String.valueOf(System.currentTimeMillis())); - workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName); - workerInfo.put("raylet_socket", rayConfig.rayletSocketName); - workerInfo.put("name", System.getProperty("user.dir")); - //TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info) - redisClient.hmset("Drivers:" + workerId, workerInfo); - } else { - workerInfo.put("node_ip_address", rayConfig.nodeIp); - workerInfo.put("plasma_store_socket", rayConfig.objectStoreSocketName); - workerInfo.put("raylet_socket", rayConfig.rayletSocketName); - //TODO: b"Workers:" + worker.workerId, - redisClient.hmset("Workers:" + workerId, workerInfo); + private static native void nativeShutdown(); + + private static native void nativeSetResource(String resourceName, double capacity, byte[] nodeId); + + private static native void nativeKillActor(byte[] actorId, boolean noReconstruction); + + private static native void nativeSetCoreWorker(byte[] workerId); + + static class AsyncContext { + + public final UniqueId workerId; + public final ClassLoader currentClassLoader; + + AsyncContext(UniqueId workerId, ClassLoader currentClassLoader) { + this.workerId = workerId; + this.currentClassLoader = currentClassLoader; } } - - private static native long nativeInitCoreWorker(int workerMode, String storeSocket, - String rayletSocket, String nodeIpAddress, int nodeManagerPort, byte[] jobId, - GcsClientOptions gcsClientOptions); - - private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer); - - private static native 
void nativeDestroyCoreWorker(long nativeCoreWorkerPointer); - - private static native void nativeSetup(String logDir, Map rayletConfigParameters); - - private static native void nativeShutdownHook(); - - private static native void nativeSetResource(long conn, String resourceName, double capacity, - byte[] nodeId); - - private static native void nativeKillActor(long nativeCoreWorkerPointer, byte[] actorId, - boolean noReconstruction); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayRuntimeInternal.java b/java/runtime/src/main/java/org/ray/runtime/RayRuntimeInternal.java new file mode 100644 index 000000000..a8a56c132 --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/RayRuntimeInternal.java @@ -0,0 +1,33 @@ +package org.ray.runtime; + +import org.ray.api.runtime.RayRuntime; +import org.ray.runtime.config.RayConfig; +import org.ray.runtime.context.WorkerContext; +import org.ray.runtime.functionmanager.FunctionManager; +import org.ray.runtime.gcs.GcsClient; +import org.ray.runtime.object.ObjectStore; + +/** + * This interface is required to make {@link RayRuntimeProxy} work. + */ +public interface RayRuntimeInternal extends RayRuntime { + + /** + * Start runtime. 
+ */ + void start(); + + WorkerContext getWorkerContext(); + + ObjectStore getObjectStore(); + + FunctionManager getFunctionManager(); + + RayConfig getRayConfig(); + + GcsClient getGcsClient(); + + void setIsContextSet(boolean isContextSet); + + void run(); +} diff --git a/java/runtime/src/main/java/org/ray/runtime/RayRuntimeProxy.java b/java/runtime/src/main/java/org/ray/runtime/RayRuntimeProxy.java new file mode 100644 index 000000000..ee1957add --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/RayRuntimeProxy.java @@ -0,0 +1,83 @@ +package org.ray.runtime; + +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import org.ray.api.exception.RayException; +import org.ray.api.runtime.RayRuntime; +import org.ray.runtime.config.RunMode; + +/** + * Protect a ray runtime with context checks for all methods of {@link RayRuntime} (except {@link + * RayRuntime#shutdown()} and {@link RayRuntime#setAsyncContext(Object)}). + */ +public class RayRuntimeProxy implements InvocationHandler { + + /** + * The original runtime. + */ + private AbstractRayRuntime obj; + + private RayRuntimeProxy(AbstractRayRuntime obj) { + this.obj = obj; + } + + public AbstractRayRuntime getRuntimeObject() { + return obj; + } + + /** + * Generate a new instance of {@link RayRuntimeInternal} with additional context check. 
+ */ + static RayRuntimeInternal newInstance(AbstractRayRuntime obj) { + return (RayRuntimeInternal) java.lang.reflect.Proxy + .newProxyInstance(obj.getClass().getClassLoader(), new Class[]{RayRuntimeInternal.class}, + new RayRuntimeProxy(obj)); + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + if (isInterfaceMethod(method) && !method.getName().equals("shutdown") && !method.getName() + .equals("setAsyncContext")) { + checkIsContextSet(); + } + try { + return method.invoke(obj, args); + } catch (InvocationTargetException e) { + if (e.getCause() != null) { + throw e.getCause(); + } else { + throw e; + } + } + } + + /** + * Whether the method is defined in the {@link RayRuntime} interface. + */ + private boolean isInterfaceMethod(Method method) { + try { + RayRuntime.class.getMethod(method.getName(), method.getParameterTypes()); + return true; + } catch (NoSuchMethodException e) { + return false; + } + } + + /** + * Check if thread context is set. + *

+ * This method should be invoked at the beginning of most public methods of {@link RayRuntime}, + * otherwise the native code might crash if the thread-local core worker is not set. We check it + * for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the + * error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode. + */ + private void checkIsContextSet() { + if (!obj.isContextSet.get()) { + throw new RayException( + "`Ray.wrap***` is not called on the current thread." + + " If you want to use Ray API in your own threads," + + " please wrap your executable with `Ray.wrap***`."); + } + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java index f7ebec3e7..21b99b783 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayActor.java @@ -7,11 +7,7 @@ import java.io.ObjectInput; import java.io.ObjectOutput; import java.util.List; import org.ray.api.BaseActor; -import org.ray.api.Ray; import org.ray.api.id.ActorId; -import org.ray.api.runtime.RayRuntime; -import org.ray.runtime.RayMultiWorkerNativeRuntime; -import org.ray.runtime.RayNativeRuntime; import org.ray.runtime.generated.Common.Language; /** @@ -20,20 +16,17 @@ import org.ray.runtime.generated.Common.Language; */ public abstract class NativeRayActor implements BaseActor, Externalizable { - /** - * Address of core worker. - */ - long nativeCoreWorkerPointer; /** * ID of the actor. 
*/ byte[] actorId; - NativeRayActor(long nativeCoreWorkerPointer, byte[] actorId) { - Preconditions.checkState(nativeCoreWorkerPointer != 0); + private Language language; + + NativeRayActor(byte[] actorId, Language language) { Preconditions.checkState(!ActorId.fromBytes(actorId).isNil()); - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; this.actorId = actorId; + this.language = language; } /** @@ -42,14 +35,12 @@ public abstract class NativeRayActor implements BaseActor, Externalizable { NativeRayActor() { } - public static NativeRayActor create(long nativeCoreWorkerPointer, byte[] actorId, - Language language) { - Preconditions.checkState(nativeCoreWorkerPointer != 0); + public static NativeRayActor create(byte[] actorId, Language language) { switch (language) { case JAVA: - return new NativeRayJavaActor(nativeCoreWorkerPointer, actorId); + return new NativeRayJavaActor(actorId); case PYTHON: - return new NativeRayPyActor(nativeCoreWorkerPointer, actorId); + return new NativeRayPyActor(actorId); default: throw new IllegalStateException("Unknown actor handle language: " + language); } @@ -61,18 +52,19 @@ public abstract class NativeRayActor implements BaseActor, Externalizable { } public Language getLanguage() { - return Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId)); + return language; } @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeObject(toBytes()); + out.writeObject(nativeSerialize(actorId)); + out.writeObject(language); } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - nativeCoreWorkerPointer = getNativeCoreWorkerPointer(); - actorId = nativeDeserialize(nativeCoreWorkerPointer, (byte[]) in.readObject()); + actorId = nativeDeserialize((byte[]) in.readObject()); + language = (Language) in.readObject(); } /** @@ -81,7 +73,7 @@ public abstract class NativeRayActor implements BaseActor, Externalizable { * @return the bytes of the actor handle 
*/ public byte[] toBytes() { - return nativeSerialize(nativeCoreWorkerPointer, actorId); + return nativeSerialize(actorId); } /** @@ -90,21 +82,10 @@ public abstract class NativeRayActor implements BaseActor, Externalizable { * @return the bytes of an actor handle */ public static NativeRayActor fromBytes(byte[] bytes) { - long nativeCoreWorkerPointer = getNativeCoreWorkerPointer(); - byte[] actorId = nativeDeserialize(nativeCoreWorkerPointer, bytes); - Language language = Language.forNumber(nativeGetLanguage(nativeCoreWorkerPointer, actorId)); + byte[] actorId = nativeDeserialize(bytes); + Language language = Language.forNumber(nativeGetLanguage(actorId)); Preconditions.checkNotNull(language); - return create(nativeCoreWorkerPointer, actorId, language); - } - - private static long getNativeCoreWorkerPointer() { - RayRuntime runtime = Ray.internal(); - if (runtime instanceof RayMultiWorkerNativeRuntime) { - runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); - } - Preconditions.checkState(runtime instanceof RayNativeRuntime); - - return ((RayNativeRuntime) runtime).getNativeCoreWorkerPointer(); + return create(actorId, language); } @Override @@ -112,13 +93,11 @@ public abstract class NativeRayActor implements BaseActor, Externalizable { // TODO(zhijunfu): do we need to free the ActorHandle in core worker? 
} - private static native int nativeGetLanguage( - long nativeCoreWorkerPointer, byte[] actorId); + private static native int nativeGetLanguage(byte[] actorId); - static native List nativeGetActorCreationTaskFunctionDescriptor( - long nativeCoreWorkerPointer, byte[] actorId); + static native List nativeGetActorCreationTaskFunctionDescriptor(byte[] actorId); - private static native byte[] nativeSerialize(long nativeCoreWorkerPointer, byte[] actorId); + private static native byte[] nativeSerialize(byte[] actorId); - private static native byte[] nativeDeserialize(long nativeCoreWorkerPointer, byte[] data); + private static native byte[] nativeDeserialize(byte[] data); } diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java index a6b096e5f..70dbc1a60 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayJavaActor.java @@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language; */ public class NativeRayJavaActor extends NativeRayActor implements RayActor { - NativeRayJavaActor(long nativeCoreWorkerPointer, byte[] actorId) { - super(nativeCoreWorkerPointer, actorId); + NativeRayJavaActor(byte[] actorId) { + super(actorId, Language.JAVA); } /** diff --git a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java index ac97bb608..3b87fa58c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java +++ b/java/runtime/src/main/java/org/ray/runtime/actor/NativeRayPyActor.java @@ -11,8 +11,8 @@ import org.ray.runtime.generated.Common.Language; */ public class NativeRayPyActor extends NativeRayActor implements RayPyActor { - NativeRayPyActor(long nativeCoreWorkerPointer, byte[] actorId) { - super(nativeCoreWorkerPointer, actorId); + NativeRayPyActor(byte[] 
actorId) { + super(actorId, Language.PYTHON); } /** @@ -24,12 +24,12 @@ public class NativeRayPyActor extends NativeRayActor implements RayPyActor { @Override public String getModuleName() { - return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(0); + return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(0); } @Override public String getClassName() { - return nativeGetActorCreationTaskFunctionDescriptor(nativeCoreWorkerPointer, actorId).get(1); + return nativeGetActorCreationTaskFunctionDescriptor(actorId).get(1); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 8671ed73d..ed3828141 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -93,10 +93,6 @@ public class RayConfig { } } - /** - * Number of threads that execute tasks. - */ - public final int numberExecThreadsForDevRuntime; public final int numWorkersPerProcess; @@ -220,9 +216,6 @@ public class RayConfig { jobResourcePath = null; } - // Number of threads that execute tasks. 
- numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism"); - numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java"); gcsServiceEnabled = System.getenv("RAY_GCS_SERVICE_ENABLED") == null || diff --git a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java index ada5aa881..a20adf0ef 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/LocalModeWorkerContext.java @@ -16,6 +16,7 @@ public class LocalModeWorkerContext implements WorkerContext { private final JobId jobId; private ThreadLocal currentTask = new ThreadLocal<>(); + private final ThreadLocal currentWorkerId = new ThreadLocal<>(); public LocalModeWorkerContext(JobId jobId) { this.jobId = jobId; @@ -23,7 +24,11 @@ public class LocalModeWorkerContext implements WorkerContext { @Override public UniqueId getCurrentWorkerId() { - throw new UnsupportedOperationException(); + return currentWorkerId.get(); + } + + public void setCurrentWorkerId(UniqueId workerId) { + currentWorkerId.set(workerId); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java index b42a7b234..2eb814898 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/NativeWorkerContext.java @@ -12,61 +12,52 @@ import org.ray.runtime.generated.Common.TaskType; */ public class NativeWorkerContext implements WorkerContext { - /** - * The native pointer of core worker. 
- */ - private final long nativeCoreWorkerPointer; - - private ClassLoader currentClassLoader; - - public NativeWorkerContext(long nativeCoreWorkerPointer) { - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; - } + private final ThreadLocal currentClassLoader = new ThreadLocal<>(); @Override public UniqueId getCurrentWorkerId() { - return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer)); + return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId()); } @Override public JobId getCurrentJobId() { - return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer)); + return JobId.fromByteBuffer(nativeGetCurrentJobId()); } @Override public ActorId getCurrentActorId() { - return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer)); + return ActorId.fromByteBuffer(nativeGetCurrentActorId()); } @Override public ClassLoader getCurrentClassLoader() { - return currentClassLoader; + return currentClassLoader.get(); } @Override public void setCurrentClassLoader(ClassLoader currentClassLoader) { - if (this.currentClassLoader != currentClassLoader) { - this.currentClassLoader = currentClassLoader; + if (this.currentClassLoader.get() != currentClassLoader) { + this.currentClassLoader.set(currentClassLoader); } } @Override public TaskType getCurrentTaskType() { - return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer)); + return TaskType.forNumber(nativeGetCurrentTaskType()); } @Override public TaskId getCurrentTaskId() { - return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer)); + return TaskId.fromByteBuffer(nativeGetCurrentTaskId()); } - private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer); + private static native int nativeGetCurrentTaskType(); - private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer); + private static native ByteBuffer nativeGetCurrentTaskId(); - private static native ByteBuffer 
nativeGetCurrentJobId(long nativeCoreWorkerPointer); + private static native ByteBuffer nativeGetCurrentJobId(); - private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer); + private static native ByteBuffer nativeGetCurrentWorkerId(); - private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer); + private static native ByteBuffer nativeGetCurrentActorId(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java index 0680128fe..97ad3badb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/context/RuntimeContextImpl.java @@ -6,15 +6,15 @@ import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.api.runtimecontext.RuntimeContext; -import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayRuntimeInternal; import org.ray.runtime.config.RunMode; import org.ray.runtime.generated.Common.TaskType; public class RuntimeContextImpl implements RuntimeContext { - private AbstractRayRuntime runtime; + private RayRuntimeInternal runtime; - public RuntimeContextImpl(AbstractRayRuntime runtime) { + public RuntimeContextImpl(RayRuntimeInternal runtime) { this.runtime = runtime; } diff --git a/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java index 9c8798f0f..1d8b863cf 100644 --- a/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/object/NativeObjectStore.java @@ -15,56 +15,48 @@ public class NativeObjectStore extends ObjectStore { private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class); - /** - * The native pointer of core worker. 
- */ - private final long nativeCoreWorkerPointer; - - public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) { + public NativeObjectStore(WorkerContext workerContext) { super(workerContext); - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; } @Override public ObjectId putRaw(NativeRayObject obj) { - return new ObjectId(nativePut(nativeCoreWorkerPointer, obj)); + return new ObjectId(nativePut(obj)); } @Override public void putRaw(NativeRayObject obj, ObjectId objectId) { - nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj); + nativePut(objectId.getBytes(), obj); } @Override public List getRaw(List objectIds, long timeoutMs) { - return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs); + return nativeGet(toBinaryList(objectIds), timeoutMs); } @Override public List wait(List objectIds, int numObjects, long timeoutMs) { - return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs); + return nativeWait(toBinaryList(objectIds), numObjects, timeoutMs); } @Override public void delete(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks); + nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks); } private static List toBinaryList(List ids) { return ids.stream().map(BaseId::getBytes).collect(Collectors.toList()); } - private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj); + private static native byte[] nativePut(NativeRayObject obj); - private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId, - NativeRayObject obj); + private static native void nativePut(byte[] objectId, NativeRayObject obj); - private static native List nativeGet(long nativeCoreWorkerPointer, - List ids, long timeoutMs); + private static native List nativeGet(List ids, long timeoutMs); - private static native List 
nativeWait(long nativeCoreWorkerPointer, - List objectIds, int numObjects, long timeoutMs); + private static native List nativeWait(List objectIds, int numObjects, + long timeoutMs); - private static native void nativeDelete(long nativeCoreWorkerPointer, List objectIds, - boolean localOnly, boolean deleteCreatingTasks); + private static native void nativeDelete(List objectIds, boolean localOnly, + boolean deleteCreatingTasks); } diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index 296fe1935..a99db00c8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -1,9 +1,7 @@ package org.ray.runtime.runner.worker; import org.ray.api.Ray; -import org.ray.api.runtime.RayRuntime; -import org.ray.runtime.RayMultiWorkerNativeRuntime; -import org.ray.runtime.RayNativeRuntime; +import org.ray.runtime.RayRuntimeInternal; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,14 +23,7 @@ public class DefaultWorker { }); Ray.init(); LOGGER.info("Worker started."); - RayRuntime runtime = Ray.internal(); - if (runtime instanceof RayNativeRuntime) { - ((RayNativeRuntime)runtime).run(); - } else if (runtime instanceof RayMultiWorkerNativeRuntime) { - ((RayMultiWorkerNativeRuntime)runtime).run(); - } else { - throw new RuntimeException("Unknown RayRuntime: " + runtime); - } + ((RayRuntimeInternal) Ray.internal()).run(); } catch (Exception e) { LOGGER.error("Failed to start worker.", e); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 750cd08e7..3dae7f5f2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -6,9 +6,7 @@ 
import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; -import org.ray.api.runtime.RayRuntime; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.RayMultiWorkerNativeRuntime; +import org.ray.runtime.RayRuntimeInternal; import org.ray.runtime.generated.Common.Language; import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectSerializer; @@ -43,11 +41,7 @@ public class ArgumentsBuilder { } else { value = ObjectSerializer.serialize(arg); if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) { - RayRuntime runtime = Ray.internal(); - if (runtime instanceof RayMultiWorkerNativeRuntime) { - runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); - } - id = ((AbstractRayRuntime) runtime).getObjectStore() + id = ((RayRuntimeInternal) Ray.internal()).getObjectStore() .putRaw(value); value = null; } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java index 24e6f15b9..9486e6168 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskExecutor.java @@ -1,17 +1,40 @@ package org.ray.runtime.task; import org.ray.api.id.ActorId; -import org.ray.runtime.AbstractRayRuntime; +import org.ray.api.id.UniqueId; +import org.ray.runtime.RayRuntimeInternal; +import org.ray.runtime.task.LocalModeTaskExecutor.LocalActorContext; /** * Task executor for local mode. */ -public class LocalModeTaskExecutor extends TaskExecutor { +public class LocalModeTaskExecutor extends TaskExecutor { - public LocalModeTaskExecutor(AbstractRayRuntime runtime) { + static class LocalActorContext extends TaskExecutor.ActorContext { + + /** + * The worker ID of the actor. 
+ */ + private final UniqueId workerId; + + public LocalActorContext(UniqueId workerId) { + this.workerId = workerId; + } + + public UniqueId getWorkerId() { + return workerId; + } + } + + public LocalModeTaskExecutor(RayRuntimeInternal runtime) { super(runtime); } + @Override + protected LocalActorContext createActorContext() { + return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId()); + } + @Override protected void maybeSaveCheckpoint(Object actor, ActorId actorId) { } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index 6e5346d7d..181ef59cc 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -4,16 +4,15 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.protobuf.ByteString; import java.nio.ByteBuffer; -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; -import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Random; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.stream.Collectors; @@ -21,9 +20,10 @@ import org.ray.api.BaseActor; import org.ray.api.id.ActorId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; +import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; -import org.ray.runtime.RayDevRuntime; +import org.ray.runtime.RayRuntimeInternal; import org.ray.runtime.actor.LocalModeRayActor; import org.ray.runtime.context.LocalModeWorkerContext; import org.ray.runtime.functionmanager.FunctionDescriptor; @@ -37,6 +37,7 @@ import 
org.ray.runtime.generated.Common.TaskSpec; import org.ray.runtime.generated.Common.TaskType; import org.ray.runtime.object.LocalModeObjectStore; import org.ray.runtime.object.NativeRayObject; +import org.ray.runtime.task.TaskExecutor.ActorContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,7 +50,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { private final Map> waitingTasks = new HashMap<>(); private final Object taskAndObjectLock = new Object(); - private final RayDevRuntime runtime; + private final RayRuntimeInternal runtime; + private final TaskExecutor taskExecutor; private final LocalModeObjectStore objectStore; /// The thread pool to execute actor tasks. @@ -58,17 +60,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { /// The thread pool to execute normal tasks. private final ExecutorService normalTaskExecutorService; - private final Deque idleTaskExecutors = new ArrayDeque<>(); - private final Map actorTaskExecutors = new HashMap<>(); - private final Object taskExecutorLock = new Object(); - private final ThreadLocal currentTaskExecutor = new ThreadLocal<>(); - public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore, - int numberThreads) { + private final Map actorContexts = new ConcurrentHashMap<>(); + + public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor, + LocalModeObjectStore objectStore) { this.runtime = runtime; + this.taskExecutor = taskExecutor; this.objectStore = objectStore; // The thread pool that executes normal tasks in parallel. - normalTaskExecutorService = Executors.newFixedThreadPool(numberThreads); + normalTaskExecutorService = Executors.newCachedThreadPool(); // The thread pool that executes actor tasks in parallel. actorTaskExecutorServices = new HashMap<>(); } @@ -88,46 +89,6 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } } - /** - * Get the worker of current thread.
NOTE: Cannot be used for multi-threading in worker. - */ - public TaskExecutor getCurrentTaskExecutor() { - return currentTaskExecutor.get(); - } - - /** - * Get a worker from the worker pool to run the given task. - */ - private TaskExecutor getTaskExecutor(TaskSpec task) { - TaskExecutor taskExecutor; - synchronized (taskExecutorLock) { - if (task.getType() == TaskType.ACTOR_TASK) { - taskExecutor = actorTaskExecutors.get(getActorId(task)); - } else if (task.getType() == TaskType.ACTOR_CREATION_TASK) { - taskExecutor = new LocalModeTaskExecutor(runtime); - actorTaskExecutors.put(getActorId(task), taskExecutor); - } else if (idleTaskExecutors.size() > 0) { - taskExecutor = idleTaskExecutors.pop(); - } else { - taskExecutor = new LocalModeTaskExecutor(runtime); - } - } - currentTaskExecutor.set(taskExecutor); - return taskExecutor; - } - - /** - * Return the worker to the worker pool. - */ - private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) { - currentTaskExecutor.remove(); - synchronized (taskExecutorLock) { - if (taskSpec.getType() == TaskType.NORMAL_TASK) { - idleTaskExecutors.push(worker); - } - } - } - private Set getUnreadyObjects(TaskSpec taskSpec) { Set unreadyObjects = new HashSet<>(); // Check whether task arguments are ready. @@ -257,32 +218,11 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { Set unreadyObjects = getUnreadyObjects(taskSpec); final Runnable runnable = () -> { - TaskExecutor taskExecutor = getTaskExecutor(taskSpec); try { - List args = getFunctionArgs(taskSpec).stream() - .map(arg -> arg.id != null ? 
- objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) - : arg.value) - .collect(Collectors.toList()); - ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); - List returnObjects = taskExecutor - .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); - ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); - List returnIds = getReturnIds(taskSpec); - for (int i = 0; i < returnIds.size(); i++) { - NativeRayObject putObject; - if (i >= returnObjects.size()) { - // If the task is an actor task or an actor creation task, - // put the dummy object in object store, so those tasks which depends on it - // can be executed. - putObject = new NativeRayObject(new byte[] {1}, null); - } else { - putObject = returnObjects.get(i); - } - objectStore.putRaw(putObject, returnIds.get(i)); - } - } finally { - returnTaskExecutor(taskExecutor, taskSpec); + executeTask(taskSpec); + } catch (Exception ex) { + LOGGER.error("Unexpected exception when executing a task.", ex); + System.exit(-1); } }; @@ -313,6 +253,52 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } } + private void executeTask(TaskSpec taskSpec) { + ActorContext actorContext = null; + if (taskSpec.getType() == TaskType.ACTOR_TASK) { + actorContext = actorContexts.get(getActorId(taskSpec)); + Preconditions.checkNotNull(actorContext); + } + taskExecutor.setActorContext(actorContext); + List args = getFunctionArgs(taskSpec).stream() + .map(arg -> arg.id != null ? + objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0) + : arg.value) + .collect(Collectors.toList()); + runtime.setIsContextSet(true); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec); + UniqueId workerId = actorContext != null + ? 
((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId() + : UniqueId.randomId(); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId); + List returnObjects = taskExecutor + .execute(getJavaFunctionDescriptor(taskSpec).toList(), args); + if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) { + // Update actor context map ASAP in case objectStore.putRaw triggered the next actor task + // on this actor. + actorContexts.put(getActorId(taskSpec), taskExecutor.getActorContext()); + } + // Set this flag to true is necessary because at the end of `taskExecutor.execute()`, + // this flag will be set to false. And `runtime.getWorkerContext()` requires it to be + // true. + runtime.setIsContextSet(true); + ((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null); + runtime.setIsContextSet(false); + List returnIds = getReturnIds(taskSpec); + for (int i = 0; i < returnIds.size(); i++) { + NativeRayObject putObject; + if (i >= returnObjects.size()) { + // If the task is an actor task or an actor creation task, + // put the dummy object in object store, so those tasks which depends on it + // can be executed. 
+ putObject = new NativeRayObject(new byte[]{1}, null); + } else { + putObject = returnObjects.get(i); + } + objectStore.putRaw(putObject, returnIds.get(i)); + } + } + private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) { org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor = taskSpec.getFunctionDescriptor(); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java index 36e7259a4..2f6a7579e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskExecutor.java @@ -8,39 +8,42 @@ import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.id.ActorId; import org.ray.api.id.UniqueId; -import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayRuntimeInternal; +import org.ray.runtime.task.NativeTaskExecutor.NativeActorContext; /** * Task executor for cluster mode. */ -public class NativeTaskExecutor extends TaskExecutor { +public class NativeTaskExecutor extends TaskExecutor { // TODO(hchen): Use the C++ config. private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20; - /** - * The native pointer of core worker. - */ - private final long nativeCoreWorkerPointer; + static class NativeActorContext extends TaskExecutor.ActorContext { - /** - * Number of tasks executed since last actor checkpoint. - */ - private int numTasksSinceLastCheckpoint = 0; + /** + * Number of tasks executed since last actor checkpoint. + */ + private int numTasksSinceLastCheckpoint = 0; - /** - * IDs of this actor's previous checkpoints. - */ - private List checkpointIds; + /** + * IDs of this actor's previous checkpoints. + */ + private List checkpointIds; - /** - * Timestamp of the last actor checkpoint. 
- */ - private long lastCheckpointTimestamp = 0; + /** + * Timestamp of the last actor checkpoint. + */ + private long lastCheckpointTimestamp = 0; + } - public NativeTaskExecutor(long nativeCoreWorkerPointer, AbstractRayRuntime runtime) { + public NativeTaskExecutor(RayRuntimeInternal runtime) { super(runtime); - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; + } + + @Override + protected NativeActorContext createActorContext() { + return new NativeActorContext(); } @Override @@ -48,15 +51,18 @@ public class NativeTaskExecutor extends TaskExecutor { if (!(actor instanceof Checkpointable)) { return; } + NativeActorContext actorContext = getActorContext(); CheckpointContext checkpointContext = new CheckpointContext(actorId, - ++numTasksSinceLastCheckpoint, System.currentTimeMillis() - lastCheckpointTimestamp); + ++actorContext.numTasksSinceLastCheckpoint, + System.currentTimeMillis() - actorContext.lastCheckpointTimestamp); Checkpointable checkpointable = (Checkpointable) actor; if (!checkpointable.shouldCheckpoint(checkpointContext)) { return; } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer)); + actorContext.numTasksSinceLastCheckpoint = 0; + actorContext.lastCheckpointTimestamp = System.currentTimeMillis(); + UniqueId checkpointId = new UniqueId(nativePrepareCheckpoint()); + List checkpointIds = actorContext.checkpointIds; checkpointIds.add(checkpointId); if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) { ((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0)); @@ -70,9 +76,10 @@ public class NativeTaskExecutor extends TaskExecutor { if (!(actor instanceof Checkpointable)) { return; } - numTasksSinceLastCheckpoint = 0; - lastCheckpointTimestamp = System.currentTimeMillis(); - checkpointIds = new ArrayList<>(); + NativeActorContext actorContext = getActorContext(); + 
actorContext.numTasksSinceLastCheckpoint = 0; + actorContext.lastCheckpointTimestamp = System.currentTimeMillis(); + actorContext.checkpointIds = new ArrayList<>(); List availableCheckpoints = runtime.getGcsClient().getCheckpointsForActor(actorId); if (availableCheckpoints.isEmpty()) { @@ -90,13 +97,11 @@ public class NativeTaskExecutor extends TaskExecutor { Preconditions.checkArgument(checkpointValid, "'loadCheckpoint' must return a checkpoint ID that exists in the " + "'availableCheckpoints' list, or null."); - - nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, checkpointId.getBytes()); + nativeNotifyActorResumedFromCheckpoint(checkpointId.getBytes()); } } - private static native byte[] nativePrepareCheckpoint(long nativeCoreWorkerPointer); + private static native byte[] nativePrepareCheckpoint(); - private static native void nativeNotifyActorResumedFromCheckpoint(long nativeCoreWorkerPointer, - byte[] checkpointId); + private static native void nativeNotifyActorResumedFromCheckpoint(byte[] checkpointId); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java index eed1d9262..f0dacf209 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/NativeTaskSubmitter.java @@ -15,30 +15,18 @@ import org.ray.runtime.functionmanager.FunctionDescriptor; */ public class NativeTaskSubmitter implements TaskSubmitter { - /** - * The native pointer of core worker. 
- */ - private final long nativeCoreWorkerPointer; - - public NativeTaskSubmitter(long nativeCoreWorkerPointer) { - this.nativeCoreWorkerPointer = nativeCoreWorkerPointer; - } - @Override public List submitTask(FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions options) { - List returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args, - numReturns, options); + List returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options); return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); } @Override public BaseActor createActor(FunctionDescriptor functionDescriptor, List args, ActorCreationOptions options) { - byte[] actorId = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args, - options); - return NativeRayActor.create(nativeCoreWorkerPointer, actorId, - functionDescriptor.getLanguage()); + byte[] actorId = nativeCreateActor(functionDescriptor, args, options); + return NativeRayActor.create(actorId, functionDescriptor.getLanguage()); } @Override @@ -46,24 +34,18 @@ public class NativeTaskSubmitter implements TaskSubmitter { BaseActor actor, FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions options) { Preconditions.checkState(actor instanceof NativeRayActor); - List returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer, - actor.getId().getBytes(), functionDescriptor, args, numReturns, - options); + List returnIds = nativeSubmitActorTask(actor.getId().getBytes(), + functionDescriptor, args, numReturns, options); return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); } - private static native List nativeSubmitTask( - long nativeCoreWorkerPointer, + private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, + List args, int numReturns, CallOptions callOptions); + + private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor, + List args, ActorCreationOptions 
actorCreationOptions); + + private static native List nativeSubmitActorTask(byte[] actorId, FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions callOptions); - - private static native byte[] nativeCreateActor( - long nativeCoreWorkerPointer, - FunctionDescriptor functionDescriptor, List args, - ActorCreationOptions actorCreationOptions); - - private static native List nativeSubmitActorTask( - long nativeCoreWorkerPointer, - byte[] actorId, FunctionDescriptor functionDescriptor, List args, - int numReturns, CallOptions callOptions); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index 3ca133af7..e044bf677 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -1,6 +1,7 @@ package org.ray.runtime.task; import com.google.common.base.Preconditions; +import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.List; import java.util.concurrent.ConcurrentHashMap; @@ -9,58 +10,75 @@ import org.ray.api.id.ActorId; import org.ray.api.id.JobId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.config.RayConfig; -import org.ray.runtime.config.RunMode; +import org.ray.runtime.RayRuntimeInternal; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.functionmanager.RayFunction; import org.ray.runtime.generated.Common.TaskType; import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectSerializer; +import org.ray.runtime.task.TaskExecutor.ActorContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * The task executor, which executes tasks assigned by raylet continuously. 
*/ -public abstract class TaskExecutor { +public abstract class TaskExecutor { private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class); - // A helper map to help we get the corresponding executor for the given worker in JNI. - private static ConcurrentHashMap taskExecutors - = new ConcurrentHashMap<>(); + protected final RayRuntimeInternal runtime; - protected final AbstractRayRuntime runtime; + private final ConcurrentHashMap actorContextMap = new ConcurrentHashMap<>(); - /** - * The current actor object, if this worker is an actor, otherwise null. - */ - protected Object currentActor = null; + static class ActorContext { - /** - * The exception that failed the actor creation task, if any. - */ - private Exception actorCreationException = null; + /** + * The current actor object, if this worker is an actor, otherwise null. + */ + Object currentActor = null; - protected TaskExecutor(AbstractRayRuntime runtime) { - this.runtime = runtime; - if (RayConfig.getInstance().runMode == RunMode.CLUSTER) { - taskExecutors.put(runtime.getWorkerContext().getCurrentWorkerId(), this); - } + /** + * The exception that failed the actor creation task, if any. + */ + Throwable actorCreationException = null; } - public static TaskExecutor get(byte[] workerId) { - return taskExecutors.get(new UniqueId(workerId)); + TaskExecutor(RayRuntimeInternal runtime) { + this.runtime = runtime; + } + + protected abstract T createActorContext(); + + T getActorContext() { + return actorContextMap.get(runtime.getWorkerContext().getCurrentWorkerId()); + } + + void setActorContext(T actorContext) { + if (actorContext == null) { + // ConcurrentHashMap doesn't allow null values. So just return here. 
+ return; + } + this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext); } protected List execute(List rayFunctionInfo, List argsBytes) { + runtime.setIsContextSet(true); JobId jobId = runtime.getWorkerContext().getCurrentJobId(); TaskType taskType = runtime.getWorkerContext().getCurrentTaskType(); TaskId taskId = runtime.getWorkerContext().getCurrentTaskId(); LOGGER.debug("Executing task {}", taskId); + T actorContext = null; + if (taskType == TaskType.ACTOR_CREATION_TASK) { + actorContext = createActorContext(); + setActorContext(actorContext); + } else if (taskType == TaskType.ACTOR_TASK) { + actorContext = getActorContext(); + Preconditions.checkNotNull(actorContext); + } + List returnObjects = new ArrayList<>(); ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); @@ -74,19 +92,26 @@ public abstract class TaskExecutor { // Get local actor object and arguments. Object actor = null; if (taskType == TaskType.ACTOR_TASK) { - if (actorCreationException != null) { - throw actorCreationException; + if (actorContext.actorCreationException != null) { + throw actorContext.actorCreationException; } - actor = currentActor; - + actor = actorContext.currentActor; } Object[] args = ArgumentsBuilder.unwrap(argsBytes, rayFunction.classLoader); // Execute the task. 
Object result; - if (!rayFunction.isConstructor()) { - result = rayFunction.getMethod().invoke(actor, args); - } else { - result = rayFunction.getConstructor().newInstance(args); + try { + if (!rayFunction.isConstructor()) { + result = rayFunction.getMethod().invoke(actor, args); + } else { + result = rayFunction.getConstructor().newInstance(args); + } + } catch (InvocationTargetException e) { + if (e.getCause() != null) { + throw e.getCause(); + } else { + throw e; + } } // Set result if (taskType != TaskType.ACTOR_CREATION_TASK) { @@ -100,10 +125,10 @@ public abstract class TaskExecutor { } else { // TODO (kfstorm): handle checkpoint in core worker. maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId()); - currentActor = result; + actorContext.currentActor = result; } LOGGER.debug("Finished executing task {}", taskId); - } catch (Exception e) { + } catch (Throwable e) { LOGGER.error("Error executing task " + taskId, e); if (taskType != TaskType.ACTOR_CREATION_TASK) { boolean hasReturn = rayFunction != null && rayFunction.hasReturn(); @@ -113,11 +138,12 @@ public abstract class TaskExecutor { .serialize(new RayTaskException("Error executing task " + taskId, e))); } } else { - actorCreationException = e; + actorContext.actorCreationException = e; } } finally { Thread.currentThread().setContextClassLoader(oldLoader); runtime.getWorkerContext().setCurrentClassLoader(null); + runtime.setIsContextSet(false); } return returnObjects; } diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index 2ea2c4374..2a12eaaac 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -98,12 +98,6 @@ ray { } } - // ---------------------------- - // configurations under SINGLE_PROCESS mode - // ---------------------------- - dev-runtime { - // Number of threads that you process tasks - execution-parallelism: 10 - } - + // Whether we enable job 
manager to submit and manage job. + enable-job-manager: false } diff --git a/java/test/src/main/java/org/ray/api/TestProgressListener.java b/java/test/src/main/java/org/ray/api/TestProgressListener.java index f1e9ea057..bc3de3953 100644 --- a/java/test/src/main/java/org/ray/api/TestProgressListener.java +++ b/java/test/src/main/java/org/ray/api/TestProgressListener.java @@ -42,6 +42,10 @@ public class TestProgressListener implements IInvokedMethodListener, ITestListen @Override public void onTestFailure(ITestResult result) { printInfo("TEST FAILURE", getFullTestName(result)); + Throwable throwable = result.getThrowable(); + if (throwable != null) { + throwable.printStackTrace(); + } } @Override diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index cd7151fcd..4a8912f89 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,11 +1,9 @@ package org.ray.api; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.function.Supplier; -import org.ray.api.runtime.RayRuntime; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.RayMultiWorkerNativeRuntime; +import org.ray.runtime.RayRuntimeInternal; +import org.ray.runtime.RayRuntimeProxy; import org.ray.runtime.config.RunMode; import org.testng.Assert; import org.testng.SkipException; @@ -79,12 +77,13 @@ public class TestUtils { Assert.assertEquals(obj.get(), "hi"); } - public static AbstractRayRuntime getRuntime() { - RayRuntime runtime = Ray.internal(); - if (runtime instanceof RayMultiWorkerNativeRuntime) { - runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); - } - Preconditions.checkState(runtime instanceof AbstractRayRuntime); - return (AbstractRayRuntime) runtime; + public static RayRuntimeInternal getRuntime() { + return (RayRuntimeInternal) Ray.internal(); + } + + public static RayRuntimeInternal 
getUnderlyingRuntime() { + RayRuntimeProxy proxy = (RayRuntimeProxy) (java.lang.reflect.Proxy + .getInvocationHandler(Ray.internal())); + return proxy.getRuntimeObject(); } } diff --git a/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java b/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java index 891ba0568..73bd9965d 100644 --- a/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClassLoaderTest.java @@ -134,7 +134,7 @@ public class ClassLoaderTest extends BaseTest { FunctionDescriptor.class, Object[].class, ActorCreationOptions.class); createActorMethod.setAccessible(true); return (RayActor) createActorMethod - .invoke(TestUtils.getRuntime(), functionDescriptor, new Object[0], null); + .invoke(TestUtils.getUnderlyingRuntime(), functionDescriptor, new Object[0], null); } private RayObject callActorFunction(RayActor rayActor, @@ -143,6 +143,6 @@ public class ClassLoaderTest extends BaseTest { BaseActor.class, FunctionDescriptor.class, Object[].class, int.class); callActorFunctionMethod.setAccessible(true); return (RayObject) callActorFunctionMethod - .invoke(TestUtils.getRuntime(), rayActor, functionDescriptor, args, numReturns); + .invoke(TestUtils.getUnderlyingRuntime(), rayActor, functionDescriptor, args, numReturns); } } diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index f862eb95c..4058b7e52 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -29,7 +29,8 @@ public class ClientExceptionTest extends BaseTest { try { TimeUnit.SECONDS.sleep(1); // kill raylet - RunManager runManager = ((RayNativeRuntime) TestUtils.getRuntime()).getRunManager(); + RunManager runManager = + ((RayNativeRuntime) TestUtils.getUnderlyingRuntime()).getRunManager(); for (Process process : 
runManager.getProcesses("raylet")) { runManager.terminateProcess("raylet", process); } diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index e493959b6..f1ed2a65a 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -14,6 +14,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.WaitResult; +import org.ray.api.exception.RayException; import org.ray.api.id.ActorId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -135,6 +136,105 @@ public class MultiThreadingTest extends BaseTest { Assert.assertEquals(actorId, actorIdTester.getId()); } + static boolean testMissingWrapRunnable() throws InterruptedException { + final RayObject fooObject = Ray.put(1); + final RayActor fooActor = Ray.createActor(Echo::new); + final Runnable[] runnables = new Runnable[]{ + () -> Ray.put(1), + () -> Ray.get(fooObject.getId()), + fooObject::get, + () -> Ray.wait(ImmutableList.of(fooObject)), + Ray::getRuntimeContext, + () -> Ray.call(MultiThreadingTest::echo, 1), + () -> Ray.createActor(Echo::new), + () -> fooActor.call(Echo::echo, 1), + }; + + // It's OK to run them in main thread. + for (Runnable runnable : runnables) { + runnable.run(); + } + + Exception[] exception = new Exception[1]; + + Thread thread = new Thread(Ray.wrapRunnable(() -> { + try { + // It would be OK to run them in another thread if wrapped the runnable. + for (Runnable runnable : runnables) { + runnable.run(); + } + } catch (Exception ex) { + exception[0] = ex; + } + })); + thread.start(); + thread.join(); + if (exception[0] != null) { + throw new RuntimeException("Exception occurred in thread.", exception[0]); + } + + thread = new Thread(() -> { + try { + // It wouldn't be OK to run them in another thread if not wrapped the runnable. 
+ for (Runnable runnable : runnables) { + Assert.expectThrows(RayException.class, runnable::run); + } + } catch (Exception ex) { + exception[0] = ex; + } + }); + thread.start(); + thread.join(); + if (exception[0] != null) { + throw new RuntimeException("Exception occurred in thread.", exception[0]); + } + + Runnable[] wrappedRunnables = new Runnable[runnables.length]; + for (int i = 0; i < runnables.length; i++) { + wrappedRunnables[i] = Ray.wrapRunnable(runnables[i]); + } + // It would be OK to run the wrapped runnables in the current thread. + for (Runnable runnable : wrappedRunnables) { + runnable.run(); + } + + // It would be OK to invoke Ray APIs after executing a wrapped runnable in the current thread. + wrappedRunnables[0].run(); + runnables[0].run(); + + // Return true here to make the Ray.call returns an RayObject. + return true; + } + + @Test + public void testMissingWrapRunnableInDriver() throws InterruptedException { + testMissingWrapRunnable(); + } + + @Test + public void testMissingWrapRunnableInWorker() { + Ray.call(MultiThreadingTest::testMissingWrapRunnable).get(); + } + + @Test + public void testGetAndSetAsyncContext() throws InterruptedException { + Object asyncContext = Ray.getAsyncContext(); + Exception[] exception = new Exception[1]; + Thread thread = new Thread(() -> { + try { + Ray.setAsyncContext(asyncContext); + Ray.put(1); + } catch (Exception ex) { + exception[0] = ex; + } + }); + thread.start(); + thread.join(); + if (exception[0] != null) { + throw new RuntimeException("Exception occurred in thread.", exception[0]); + } + } + private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) { ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS); diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 85c4512e1..05b0dfca7 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -17,7 +17,7 @@ from ray.includes.common cimport ( CBuffer, CRayObject ) -from ray.includes.libcoreworker 
cimport CCoreWorker +from ray.includes.libcoreworker cimport CFiberEvent from ray.includes.unique_ids cimport ( CObjectID, CActorID @@ -72,7 +72,7 @@ cdef class ActorID(BaseID): cdef class CoreWorker: cdef: - unique_ptr[CCoreWorker] core_worker + c_bool is_driver object async_thread object async_event_loop object plasma_event_handler @@ -85,6 +85,7 @@ cdef class CoreWorker: cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] return_ids, c_vector[shared_ptr[CRayObject]] *returns) + cdef yield_current_fiber(self, CFiberEvent &fiber_event) cdef class FunctionDescriptor: cdef: diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index acedced0e..76cf05c4e 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -69,7 +69,8 @@ from ray.includes.unique_ids cimport ( ) from ray.includes.libcoreworker cimport ( CActorCreationOptions, - CCoreWorker, + CCoreWorkerOptions, + CCoreWorkerProcess, CTaskOptions, ResourceMappingType, CFiberEvent, @@ -312,7 +313,7 @@ cdef execute_task( dict execution_infos = manager.execution_infos CoreWorker core_worker = worker.core_worker JobID job_id = core_worker.get_current_job_id() - CTaskID task_id = core_worker.core_worker.get().GetCurrentTaskId() + TaskID task_id = core_worker.get_current_task_id() CFiberEvent task_done_event # Automatically restrict the GPUs available to this task. 
@@ -339,7 +340,7 @@ cdef execute_task( function_name = execution_info.function_name extra_data = (b'{"name": ' + function_name.encode("ascii") + - b' "task_id": ' + task_id.Hex() + b'}') + b' "task_id": ' + task_id.hex().encode("ascii") + b'}') if task_type == TASK_TYPE_NORMAL_TASK: title = "ray::{}()".format(function_name) @@ -396,9 +397,7 @@ cdef execute_task( monitor_state.unregister_coroutine(coroutine) future.add_done_callback(callback) - with nogil: - (core_worker.core_worker.get() - .YieldCurrentFiber(task_done_event)) + core_worker.yield_current_fiber(task_done_event) return future.result() @@ -499,8 +498,7 @@ cdef CRayStatus task_execution_handler( const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, - c_vector[shared_ptr[CRayObject]] *returns, - const CWorkerID &c_worker_id) nogil: + c_vector[shared_ptr[CRayObject]] *returns) nogil: with gil: try: @@ -645,43 +643,76 @@ cdef class CoreWorker: def __cinit__(self, is_driver, store_socket, raylet_socket, JobID job_id, GcsClientOptions gcs_options, log_dir, - node_ip_address, node_manager_port, local_mode): - use_driver = is_driver or local_mode - self.core_worker.reset(new CCoreWorker( - WORKER_TYPE_DRIVER if use_driver else WORKER_TYPE_WORKER, - LANGUAGE_PYTHON, store_socket.encode("ascii"), - raylet_socket.encode("ascii"), job_id.native(), - gcs_options.native()[0], log_dir.encode("utf-8"), - node_ip_address.encode("utf-8"), node_manager_port, - task_execution_handler, check_signals, gc_collect, - get_py_stack, True, local_mode)) + node_ip_address, node_manager_port, local_mode, + driver_name, stdout_file, stderr_file): + self.is_driver = is_driver self.is_local_mode = local_mode + cdef CCoreWorkerOptions options = CCoreWorkerOptions() + options.worker_type = ( + WORKER_TYPE_DRIVER if is_driver else WORKER_TYPE_WORKER) + options.language = LANGUAGE_PYTHON + options.store_socket = store_socket.encode("ascii") + 
options.raylet_socket = raylet_socket.encode("ascii") + options.job_id = job_id.native() + options.gcs_options = gcs_options.native()[0] + options.log_dir = log_dir.encode("utf-8") + options.install_failure_signal_handler = True + options.node_ip_address = node_ip_address.encode("utf-8") + options.node_manager_port = node_manager_port + options.driver_name = driver_name + options.stdout_file = stdout_file + options.stderr_file = stderr_file + options.task_execution_callback = task_execution_handler + options.check_signals = check_signals + options.gc_collect = gc_collect + options.get_lang_stack = get_py_stack + options.ref_counting_enabled = True + options.is_local_mode = local_mode + options.num_workers = 1 + + CCoreWorkerProcess.Initialize(options) + + def __dealloc__(self): + with nogil: + # If it's a worker, the core worker process should have been + # shutdown. So we can't call + # `CCoreWorkerProcess.GetCoreWorker().GetWorkerType()` here. + # Instead, we use the cached `is_driver` flag to test if it's a + # driver. 
+ if self.is_driver: + CCoreWorkerProcess.Shutdown() + def run_task_loop(self): with nogil: - self.core_worker.get().StartExecutingTasks() + CCoreWorkerProcess.RunTaskExecutionLoop() def get_current_task_id(self): - return TaskID(self.core_worker.get().GetCurrentTaskId().Binary()) + return TaskID( + CCoreWorkerProcess.GetCoreWorker().GetCurrentTaskId().Binary()) def get_current_job_id(self): - return JobID(self.core_worker.get().GetCurrentJobId().Binary()) + return JobID( + CCoreWorkerProcess.GetCoreWorker().GetCurrentJobId().Binary()) def get_actor_id(self): - return ActorID(self.core_worker.get().GetActorId().Binary()) + return ActorID( + CCoreWorkerProcess.GetCoreWorker().GetActorId().Binary()) def set_webui_display(self, key, message): - self.core_worker.get().SetWebuiDisplay(key, message) + CCoreWorkerProcess.GetCoreWorker().SetWebuiDisplay(key, message) def set_actor_title(self, title): - self.core_worker.get().SetActorTitle(title) + CCoreWorkerProcess.GetCoreWorker().SetActorTitle(title) def set_plasma_added_callback(self, plasma_event_handler): self.plasma_event_handler = plasma_event_handler - self.core_worker.get().SetPlasmaAddedCallback(async_plasma_callback) + CCoreWorkerProcess.GetCoreWorker().SetPlasmaAddedCallback( + async_plasma_callback) def subscribe_to_plasma_object(self, ObjectID object_id): - self.core_worker.get().SubscribeToPlasmaAdd(object_id.native()) + CCoreWorkerProcess.GetCoreWorker().SubscribeToPlasmaAdd( + object_id.native()) def get_plasma_event_handler(self): return self.plasma_event_handler @@ -694,7 +725,7 @@ cdef class CoreWorker: c_vector[CObjectID] c_object_ids = ObjectIDsToVector(object_ids) with nogil: - check_status(self.core_worker.get().Get( + check_status(CCoreWorkerProcess.GetCoreWorker().Get( c_object_ids, timeout_ms, &results)) return RayObjectsToDataMetadataPairs(results) @@ -705,7 +736,7 @@ cdef class CoreWorker: CObjectID c_object_id = object_id.native() with nogil: - check_status(self.core_worker.get().Contains( + 
check_status(CCoreWorkerProcess.GetCoreWorker().Contains( c_object_id, &has_object)) return has_object @@ -716,13 +747,13 @@ cdef class CoreWorker: CObjectID *c_object_id, shared_ptr[CBuffer] *data): if object_id is None: with nogil: - check_status(self.core_worker.get().Create( + check_status(CCoreWorkerProcess.GetCoreWorker().Create( metadata, data_size, contained_ids, c_object_id, data)) else: c_object_id[0] = object_id.native() with nogil: - check_status(self.core_worker.get().Create( + check_status(CCoreWorkerProcess.GetCoreWorker().Create( metadata, data_size, c_object_id[0], data)) @@ -752,7 +783,7 @@ cdef class CoreWorker: write_serialized_object(serialized_object, data) if self.is_local_mode: c_object_id_vector.push_back(c_object_id) - check_status(self.core_worker.get().Put( + check_status(CCoreWorkerProcess.GetCoreWorker().Put( CRayObject(data, metadata, c_object_id_vector), c_object_id_vector, c_object_id)) else: @@ -760,7 +791,7 @@ cdef class CoreWorker: # Using custom object IDs is not supported because we can't # track their lifecycle, so we don't pin the object in this # case. 
- check_status(self.core_worker.get().Seal( + check_status(CCoreWorkerProcess.GetCoreWorker().Seal( c_object_id, pin_object and object_id is None)) @@ -775,7 +806,7 @@ cdef class CoreWorker: wait_ids = ObjectIDsToVector(object_ids) with nogil: - check_status(self.core_worker.get().Wait( + check_status(CCoreWorkerProcess.GetCoreWorker().Wait( wait_ids, num_returns, timeout_ms, &results)) assert len(results) == len(object_ids) @@ -795,19 +826,19 @@ cdef class CoreWorker: c_vector[CObjectID] free_ids = ObjectIDsToVector(object_ids) with nogil: - check_status(self.core_worker.get().Delete( + check_status(CCoreWorkerProcess.GetCoreWorker().Delete( free_ids, local_only, delete_creating_tasks)) def global_gc(self): with nogil: - self.core_worker.get().TriggerGlobalGC() + CCoreWorkerProcess.GetCoreWorker().TriggerGlobalGC() def set_object_store_client_options(self, client_name, int64_t limit_bytes): try: logger.debug("Setting plasma memory limit to {} for {}".format( limit_bytes, client_name)) - check_status(self.core_worker.get().SetClientOptions( + check_status(CCoreWorkerProcess.GetCoreWorker().SetClientOptions( client_name.encode("ascii"), limit_bytes)) except RayError as e: self.dump_object_store_memory_usage() @@ -820,7 +851,7 @@ cdef class CoreWorker: limit_bytes, client_name, e)) def dump_object_store_memory_usage(self): - message = self.core_worker.get().MemoryUsageString() + message = CCoreWorkerProcess.GetCoreWorker().MemoryUsageString() logger.warning("Local object store memory usage:\n{}\n".format( message.decode("utf-8"))) @@ -847,7 +878,7 @@ cdef class CoreWorker: prepare_args(self, args, &args_vector) with nogil: - check_status(self.core_worker.get().SubmitTask( + check_status(CCoreWorkerProcess.GetCoreWorker().SubmitTask( ray_function, args_vector, task_options, &return_ids, max_retries)) @@ -880,7 +911,7 @@ cdef class CoreWorker: prepare_args(self, args, &args_vector) with nogil: - check_status(self.core_worker.get().CreateActor( + 
check_status(CCoreWorkerProcess.GetCoreWorker().CreateActor( ray_function, args_vector, CActorCreationOptions( max_reconstructions, max_concurrency, @@ -916,10 +947,11 @@ cdef class CoreWorker: prepare_args(self, args, &args_vector) with nogil: - check_status(self.core_worker.get().SubmitActorTask( - c_actor_id, - ray_function, - args_vector, task_options, &return_ids)) + check_status( + CCoreWorkerProcess.GetCoreWorker().SubmitActorTask( + c_actor_id, + ray_function, + args_vector, task_options, &return_ids)) return VectorToObjectIDs(return_ids) @@ -928,13 +960,13 @@ cdef class CoreWorker: CActorID c_actor_id = actor_id.native() with nogil: - check_status(self.core_worker.get().KillActor( + check_status(CCoreWorkerProcess.GetCoreWorker().KillActor( c_actor_id, True, no_reconstruction)) def resource_ids(self): cdef: ResourceMappingType resource_mapping = ( - self.core_worker.get().GetResourceIDs()) + CCoreWorkerProcess.GetCoreWorker().GetResourceIDs()) unordered_map[ c_string, c_vector[pair[int64_t, double]] ].iterator iterator = resource_mapping.begin() @@ -955,13 +987,14 @@ cdef class CoreWorker: def profile_event(self, c_string event_type, object extra_data=None): return ProfileEvent.make( - self.core_worker.get().CreateProfileEvent(event_type), + CCoreWorkerProcess.GetCoreWorker().CreateProfileEvent(event_type), extra_data) def remove_actor_handle_reference(self, ActorID actor_id): cdef: CActorID c_actor_id = actor_id.native() - self.core_worker.get().RemoveActorHandleReference(c_actor_id) + CCoreWorkerProcess.GetCoreWorker().RemoveActorHandleReference( + c_actor_id) def deserialize_and_register_actor_handle(self, const c_string &bytes, ObjectID @@ -974,9 +1007,10 @@ cdef class CoreWorker: worker = ray.worker.global_worker worker.check_connected() manager = worker.function_actor_manager - c_actor_id = self.core_worker.get().DeserializeAndRegisterActorHandle( - bytes, c_outer_object_id) - check_status(self.core_worker.get().GetActorHandle( + c_actor_id = 
(CCoreWorkerProcess.GetCoreWorker() + .DeserializeAndRegisterActorHandle( + bytes, c_outer_object_id)) + check_status(CCoreWorkerProcess.GetCoreWorker().GetActorHandle( c_actor_id, &c_actor_handle)) actor_id = ActorID(c_actor_id.Binary()) job_id = JobID(c_actor_handle.CreationJobID().Binary()) @@ -1017,24 +1051,26 @@ cdef class CoreWorker: cdef: c_string output CObjectID c_actor_handle_id - check_status(self.core_worker.get().SerializeActorHandle( + check_status(CCoreWorkerProcess.GetCoreWorker().SerializeActorHandle( actor_id.native(), &output, &c_actor_handle_id)) return output, ObjectID(c_actor_handle_id.Binary()) def add_object_id_reference(self, ObjectID object_id): # Note: faster to not release GIL for short-running op. - self.core_worker.get().AddLocalReference(object_id.native()) + CCoreWorkerProcess.GetCoreWorker().AddLocalReference( + object_id.native()) def remove_object_id_reference(self, ObjectID object_id): # Note: faster to not release GIL for short-running op. - self.core_worker.get().RemoveLocalReference(object_id.native()) + CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference( + object_id.native()) def serialize_and_promote_object_id(self, ObjectID object_id): cdef: CObjectID c_object_id = object_id.native() CTaskID c_owner_id = CTaskID.Nil() CAddress c_owner_address = CAddress() - self.core_worker.get().PromoteToPlasmaAndGetOwnershipInfo( + CCoreWorkerProcess.GetCoreWorker().PromoteToPlasmaAndGetOwnershipInfo( c_object_id, &c_owner_id, &c_owner_address) return (object_id, TaskID(c_owner_id.Binary()), @@ -1053,11 +1089,12 @@ cdef class CoreWorker: CAddress c_owner_address = CAddress() c_owner_address.ParseFromString(serialized_owner_address) - self.core_worker.get().RegisterOwnershipInfoAndResolveFuture( + (CCoreWorkerProcess.GetCoreWorker() + .RegisterOwnershipInfoAndResolveFuture( c_object_id, c_outer_object_id, c_owner_id, - c_owner_address) + c_owner_address)) cdef store_task_outputs( self, worker, outputs, const c_vector[CObjectID] 
return_ids, @@ -1088,8 +1125,10 @@ cdef class CoreWorker: ObjectIDsToVector(serialized_object.contained_object_ids)) with nogil: - check_status(self.core_worker.get().AllocateReturnObjects( - return_ids, data_sizes, metadatas, contained_ids, returns)) + check_status(CCoreWorkerProcess.GetCoreWorker() + .AllocateReturnObjects( + return_ids, data_sizes, metadatas, contained_ids, + returns)) for i, serialized_object in enumerate(serialized_objects): # A nullptr is returned if the object already exists. @@ -1099,7 +1138,7 @@ cdef class CoreWorker: if self.is_local_mode: return_ids_vector.push_back(return_ids[i]) check_status( - self.core_worker.get().Put( + CCoreWorkerProcess.GetCoreWorker().Put( CRayObject(returns[0][i].get().GetData(), returns[0][i].get().GetMetadata(), return_ids_vector), @@ -1138,7 +1177,7 @@ cdef class CoreWorker: future = asyncio.run_coroutine_threadsafe(coroutine, loop) future.add_done_callback(lambda _: event.Notify()) with nogil: - (self.core_worker.get() + (CCoreWorkerProcess.GetCoreWorker() .YieldCurrentFiber(event)) return future.result() @@ -1149,14 +1188,20 @@ cdef class CoreWorker: self.async_thread.join() def current_actor_is_asyncio(self): - return self.core_worker.get().GetWorkerContext().CurrentActorIsAsync() + return (CCoreWorkerProcess.GetCoreWorker().GetWorkerContext() + .CurrentActorIsAsync()) + + cdef yield_current_fiber(self, CFiberEvent &fiber_event): + with nogil: + CCoreWorkerProcess.GetCoreWorker().YieldCurrentFiber(fiber_event) def get_all_reference_counts(self): cdef: unordered_map[CObjectID, pair[size_t, size_t]] c_ref_counts unordered_map[CObjectID, pair[size_t, size_t]].iterator it - c_ref_counts = self.core_worker.get().GetAllReferenceCounts() + c_ref_counts = ( + CCoreWorkerProcess.GetCoreWorker().GetAllReferenceCounts()) it = c_ref_counts.begin() ref_counts = {} @@ -1170,7 +1215,7 @@ cdef class CoreWorker: return ref_counts def in_memory_store_get_async(self, ObjectID object_id, future): - 
self.core_worker.get().GetAsync( + CCoreWorkerProcess.GetCoreWorker().GetAsync( object_id.native(), async_set_result_callback, async_retry_with_plasma_callback, @@ -1178,7 +1223,7 @@ cdef class CoreWorker: def push_error(self, JobID job_id, error_type, error_message, double timestamp): - check_status(self.core_worker.get().PushError( + check_status(CCoreWorkerProcess.GetCoreWorker().PushError( job_id.native(), error_type.encode("ascii"), error_message.encode("ascii"), timestamp)) @@ -1190,18 +1235,21 @@ cdef class CoreWorker: # PrepareActorCheckpoint will wait for raylet's reply, release # the GIL so other Python threads can run. with nogil: - check_status(self.core_worker.get().PrepareActorCheckpoint( - c_actor_id, &checkpoint_id)) + check_status( + CCoreWorkerProcess.GetCoreWorker() + .PrepareActorCheckpoint(c_actor_id, &checkpoint_id)) return ActorCheckpointID(checkpoint_id.Binary()) def notify_actor_resumed_from_checkpoint(self, ActorID actor_id, ActorCheckpointID checkpoint_id): - check_status(self.core_worker.get().NotifyActorResumedFromCheckpoint( - actor_id.native(), checkpoint_id.native())) + check_status( + CCoreWorkerProcess.GetCoreWorker() + .NotifyActorResumedFromCheckpoint( + actor_id.native(), checkpoint_id.native())) def set_resource(self, basestring resource_name, double capacity, ClientID client_id): - self.core_worker.get().SetResource( + CCoreWorkerProcess.GetCoreWorker().SetResource( resource_name.encode("ascii"), capacity, CClientID.FromBinary(client_id.binary())) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index d903f7e60..d54e7a015 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -17,7 +17,6 @@ from ray.includes.unique_ids cimport ( CJobID, CTaskID, CObjectID, - CWorkerID, ) from ray.includes.common cimport ( CAddress, @@ -80,31 +79,9 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_string ExtensionData() const cdef cppclass 
CCoreWorker "ray::CoreWorker": - CCoreWorker(const CWorkerType worker_type, const CLanguage language, - const c_string &store_socket, - const c_string &raylet_socket, const CJobID &job_id, - const CGcsClientOptions &gcs_options, - const c_string &log_dir, const c_string &node_ip_address, - int node_manager_port, - CRayStatus ( - CTaskType task_type, - const CRayFunction &ray_function, - const unordered_map[c_string, double] &resources, - const c_vector[shared_ptr[CRayObject]] &args, - const c_vector[CObjectID] &arg_reference_ids, - const c_vector[CObjectID] &return_ids, - c_vector[shared_ptr[CRayObject]] *returns, - const CWorkerID &worker_id) nogil, - CRayStatus() nogil, - void() nogil, - void(c_string *stack_out) nogil, - c_bool ref_counting_enabled, - c_bool local_worker) CWorkerType &GetWorkerType() CLanguage &GetLanguage() - void StartExecutingTasks() - CRayStatus SubmitTask( const CRayFunction &function, const c_vector[CTaskArg] &args, const CTaskOptions &options, c_vector[CObjectID] *return_ids, @@ -206,3 +183,46 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: void SetPlasmaAddedCallback(plasma_callback_function callback) void SubscribeToPlasmaAdd(const CObjectID &object_id) + + cdef cppclass CCoreWorkerOptions "ray::CoreWorkerOptions": + CWorkerType worker_type + CLanguage language + c_string store_socket + c_string raylet_socket + CJobID job_id + CGcsClientOptions gcs_options + c_string log_dir + c_bool install_failure_signal_handler + c_string node_ip_address + int node_manager_port + c_string driver_name + c_string stdout_file + c_string stderr_file + (CRayStatus( + CTaskType task_type, + const CRayFunction &ray_function, + const unordered_map[c_string, double] &resources, + const c_vector[shared_ptr[CRayObject]] &args, + const c_vector[CObjectID] &arg_reference_ids, + const c_vector[CObjectID] &return_ids, + c_vector[shared_ptr[CRayObject]] *returns) nogil + ) task_execution_callback + (CRayStatus() nogil) check_signals + (void() nogil) 
gc_collect + (void(c_string *stack_out) nogil) get_lang_stack + c_bool ref_counting_enabled + c_bool is_local_mode + int num_workers + CCoreWorkerOptions() + + cdef cppclass CCoreWorkerProcess "ray::CoreWorkerProcess": + @staticmethod + void Initialize(const CCoreWorkerOptions &options) + # Only call this in CoreWorker.__cinit__, + # use CoreWorker.core_worker to access C++ CoreWorker. + @staticmethod + CCoreWorker &GetCoreWorker() + @staticmethod + void Shutdown() + @staticmethod + void RunTaskExecutionLoop() diff --git a/python/ray/worker.py b/python/ray/worker.py index 7bea64849..18313d91a 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1173,27 +1173,14 @@ def connect(node, ray.state.state._initialize_global_state( node.redis_address, redis_password=node.redis_password) - # Register the worker with Redis. + driver_name = "" + log_stdout_file_name = "" + log_stderr_file_name = "" if mode == SCRIPT_MODE: - # The concept of a driver is the same as the concept of a "job". - # Register the driver/job with Redis here. import __main__ as main - driver_info = { - "node_ip_address": node.node_ip_address, - "driver_id": worker.worker_id, - "start_time": time.time(), - "plasma_store_socket": node.plasma_store_socket_name, - "raylet_socket": node.raylet_socket_name, - "name": (main.__file__ - if hasattr(main, "__file__") else "INTERACTIVE MODE") - } - worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) + driver_name = (main.__file__ + if hasattr(main, "__file__") else "INTERACTIVE MODE") elif mode == WORKER_MODE: - # Register the worker with Redis. - worker_dict = { - "node_ip_address": node.node_ip_address, - "plasma_store_socket": node.plasma_store_socket_name, - } # Check the RedirectOutput key in Redis and based on its value redirect # worker output and error to their own files. # This key is set in services.py when Redis is started. 
@@ -1224,14 +1211,12 @@ def connect(node, print("Ray worker pid: {}".format(os.getpid()), file=sys.stderr) sys.stdout.flush() sys.stderr.flush() - - worker_dict["stdout_file"] = os.path.abspath( + log_stdout_file_name = os.path.abspath( (log_stdout_file if log_stdout_file is not None else sys.stdout).name) - worker_dict["stderr_file"] = os.path.abspath( + log_stderr_file_name = os.path.abspath( (log_stderr_file if log_stderr_file is not None else sys.stderr).name) - worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict) elif not LOCAL_MODE: raise ValueError( "Invalid worker mode. Expected DRIVER, WORKER or LOCAL.") @@ -1242,9 +1227,19 @@ def connect(node, node.redis_password, ) worker.core_worker = ray._raylet.CoreWorker( - (mode == SCRIPT_MODE), node.plasma_store_socket_name, - node.raylet_socket_name, job_id, gcs_options, node.get_logs_dir_path(), - node.node_ip_address, node.node_manager_port, mode == LOCAL_MODE) + (mode == SCRIPT_MODE or mode == LOCAL_MODE), + node.plasma_store_socket_name, + node.raylet_socket_name, + job_id, + gcs_options, + node.get_logs_dir_path(), + node.node_ip_address, + node.node_manager_port, + (mode == LOCAL_MODE), + driver_name, + log_stdout_file_name, + log_stderr_file_name, + ) if driver_object_store_memory is not None: worker.core_worker.set_object_store_client_options( diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index cd4998e8c..c0b568199 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -63,10 +63,10 @@ struct WorkerThreadContext { thread_local std::unique_ptr WorkerContext::thread_context_ = nullptr; -WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) +WorkerContext::WorkerContext(WorkerType worker_type, const WorkerID &worker_id, + const JobID &job_id) : worker_type_(worker_type), - worker_id_(worker_type_ == WorkerType::DRIVER ? 
ComputeDriverIdFromJob(job_id) - : WorkerID::FromRandom()), + worker_id_(worker_id), current_job_id_(worker_type_ == WorkerType::DRIVER ? job_id : JobID::Nil()), current_actor_id_(ActorID::Nil()), main_thread_id_(boost::this_thread::get_id()) { diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 6ee7f5604..cfea89b3f 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -26,7 +26,7 @@ struct WorkerThreadContext; class WorkerContext { public: - WorkerContext(WorkerType worker_type, const JobID &job_id); + WorkerContext(WorkerType worker_type, const WorkerID &worker_id, const JobID &job_id); const WorkerType GetWorkerType() const; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 783cff3dd..6f625c101 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -75,63 +75,214 @@ void GroupObjectIdsByStoreProvider(const std::vector &object_ids, namespace ray { -CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, - const std::string &store_socket, const std::string &raylet_socket, - const JobID &job_id, const gcs::GcsClientOptions &gcs_options, - const std::string &log_dir, const std::string &node_ip_address, - int node_manager_port, - const TaskExecutionCallback &task_execution_callback, - std::function check_signals, - std::function gc_collect, - std::function get_lang_stack, - bool ref_counting_enabled, bool local_mode) - : worker_type_(worker_type), - language_(language), - log_dir_(log_dir), - ref_counting_enabled_(ref_counting_enabled), - is_local_mode_(local_mode), - check_signals_(check_signals), - gc_collect_(gc_collect), - get_call_site_(RayConfig::instance().record_ref_creation_sites() ? 
get_lang_stack - : nullptr), - worker_context_(worker_type, job_id), +std::unique_ptr CoreWorkerProcess::instance_; + +thread_local std::weak_ptr CoreWorkerProcess::current_core_worker_; + +void CoreWorkerProcess::Initialize(const CoreWorkerOptions &options) { + RAY_CHECK(!instance_) << "The process is already initialized for core worker."; + instance_ = std::unique_ptr(new CoreWorkerProcess(options)); +} + +void CoreWorkerProcess::Shutdown() { + if (!instance_) { + return; + } + RAY_CHECK(instance_->options_.worker_type == WorkerType::DRIVER) + << "The `Shutdown` interface is for driver only."; + RAY_CHECK(instance_->global_worker_); + instance_->global_worker_->Disconnect(); + instance_->global_worker_->Shutdown(); + instance_->RemoveWorker(instance_->global_worker_); + instance_.reset(); +} + +bool CoreWorkerProcess::IsInitialized() { return instance_ != nullptr; } + +CoreWorkerProcess::CoreWorkerProcess(const CoreWorkerOptions &options) + : options_(options), + global_worker_id_( + options.worker_type == WorkerType::DRIVER + ? ComputeDriverIdFromJob(options_.job_id) + : (options_.num_workers == 1 ? WorkerID::FromRandom() : WorkerID::Nil())) { + // Initialize logging if log_dir is passed. Otherwise, it must be initialized + // and cleaned up by the caller. + if (options_.log_dir != "") { + std::stringstream app_name; + app_name << LanguageString(options_.language) << "-core-" + << WorkerTypeString(options_.worker_type); + if (!global_worker_id_.IsNil()) { + app_name << "-" << global_worker_id_; + } + RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, options_.log_dir); + if (options_.install_failure_signal_handler) { + RayLog::InstallFailureSignalHandler(); + } + } + + RAY_CHECK(options_.num_workers > 0); + if (options_.worker_type == WorkerType::DRIVER) { + // Driver process can only contain one worker. + RAY_CHECK(options_.num_workers == 1); + } + + RAY_LOG(INFO) << "Constructing CoreWorkerProcess. 
pid: " << getpid(); + + if (options_.num_workers == 1) { + // We need to create the worker instance here if: + // 1. This is a driver process. In this case, the driver is ready to use right after + // the CoreWorkerProcess::Initialize. + // 2. This is a Python worker process. In this case, Python will invoke some core + // worker APIs before `CoreWorkerProcess::RunTaskExecutionLoop` is called. So we need + // to create the worker instance here. One example of invocations is + // https://github.com/ray-project/ray/blob/45ce40e5d44801193220d2c546be8de0feeef988/python/ray/worker.py#L1281. + if (options_.worker_type == WorkerType::DRIVER || + options_.language == Language::PYTHON) { + CreateWorker(); + } + } +} + +CoreWorkerProcess::~CoreWorkerProcess() { + RAY_LOG(INFO) << "Destructing CoreWorkerProcess. pid: " << getpid(); + { + // Check that all `CoreWorker` instances have been removed. + absl::ReaderMutexLock lock(&worker_map_mutex_); + RAY_CHECK(workers_.empty()); + } + if (options_.log_dir != "") { + RayLog::ShutDownRayLog(); + } +} + +void CoreWorkerProcess::EnsureInitialized() { + RAY_CHECK(instance_) << "The core worker process is not initialized yet or already " + << "shutdown."; +} + +CoreWorker &CoreWorkerProcess::GetCoreWorker() { + EnsureInitialized(); + if (instance_->options_.num_workers == 1) { + return *instance_->global_worker_; + } + auto ptr = current_core_worker_.lock(); + RAY_CHECK(ptr != nullptr) + << "The current thread is not bound with a core worker instance."; + return *ptr; +} + +void CoreWorkerProcess::SetCurrentThreadWorkerId(const WorkerID &worker_id) { + EnsureInitialized(); + if (instance_->options_.num_workers == 1) { + RAY_CHECK(instance_->global_worker_->GetWorkerID() == worker_id); + return; + } + current_core_worker_ = instance_->GetWorker(worker_id); +} + +std::shared_ptr CoreWorkerProcess::GetWorker( + const WorkerID &worker_id) const { + absl::ReaderMutexLock lock(&worker_map_mutex_); + auto it = workers_.find(worker_id); + 
RAY_CHECK(it != workers_.end()) << "Worker " << worker_id << " not found."; + return it->second; +} + +std::shared_ptr CoreWorkerProcess::CreateWorker() { + auto worker = std::make_shared( + options_, + global_worker_id_ != WorkerID::Nil() ? global_worker_id_ : WorkerID::FromRandom()); + RAY_LOG(INFO) << "Worker " << worker->GetWorkerID() << " is created."; + if (options_.num_workers == 1) { + global_worker_ = worker; + } + current_core_worker_ = worker; + + absl::MutexLock lock(&worker_map_mutex_); + workers_.emplace(worker->GetWorkerID(), worker); + RAY_CHECK(workers_.size() <= static_cast(options_.num_workers)); + return worker; +} + +void CoreWorkerProcess::RemoveWorker(std::shared_ptr worker) { + worker->WaitForShutdown(); + if (global_worker_) { + RAY_CHECK(global_worker_ == worker); + } else { + RAY_CHECK(current_core_worker_.lock() == worker); + } + current_core_worker_.reset(); + { + absl::MutexLock lock(&worker_map_mutex_); + workers_.erase(worker->GetWorkerID()); + RAY_LOG(INFO) << "Removed worker " << worker->GetWorkerID(); + } + if (global_worker_ == worker) { + global_worker_ = nullptr; + } +} + +void CoreWorkerProcess::RunTaskExecutionLoop() { + EnsureInitialized(); + RAY_CHECK(instance_->options_.worker_type == WorkerType::WORKER); + if (instance_->options_.num_workers == 1) { + // Run the task loop in the current thread only if the number of workers is 1. + auto worker = + instance_->global_worker_ ? 
instance_->global_worker_ : instance_->CreateWorker(); + worker->RunTaskExecutionLoop(); + instance_->RemoveWorker(worker); + } else { + std::vector worker_threads; + for (int i = 0; i < instance_->options_.num_workers; i++) { + worker_threads.emplace_back([]() { + auto worker = instance_->CreateWorker(); + worker->RunTaskExecutionLoop(); + instance_->RemoveWorker(worker); + }); + } + for (auto &thread : worker_threads) { + thread.join(); + } + } + + instance_.reset(); +} + +CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id) + : options_(options), + get_call_site_(RayConfig::instance().record_ref_creation_sites() + ? options_.get_lang_stack + : nullptr), + worker_context_(options_.worker_type, worker_id, options_.job_id), io_work_(io_service_), client_call_manager_(new rpc::ClientCallManager(io_service_)), death_check_timer_(io_service_), internal_timer_(io_service_), - core_worker_server_(WorkerTypeString(worker_type), 0 /* let grpc choose a port */), + core_worker_server_(WorkerTypeString(options_.worker_type), + 0 /* let grpc choose a port */), task_queue_length_(0), num_executed_tasks_(0), task_execution_service_work_(task_execution_service_), - task_execution_callback_(task_execution_callback), resource_ids_(new ResourceMappingType()), grpc_service_(io_service_, *this) { - // Initialize logging if log_dir is passed. Otherwise, it must be initialized - // and cleaned up by the caller. - if (log_dir_ != "") { - std::stringstream app_name; - app_name << LanguageString(language_) << "-" << WorkerTypeString(worker_type_) << "-" - << worker_context_.GetWorkerID(); - RayLog::StartRayLog(app_name.str(), RayLogLevel::INFO, log_dir_); - RayLog::InstallFailureSignalHandler(); - } // Initialize gcs client. 
if (RayConfig::instance().gcs_service_enabled()) { - gcs_client_ = std::make_shared(gcs_options); + gcs_client_ = std::make_shared(options_.gcs_options); } else { - gcs_client_ = std::make_shared(gcs_options); + gcs_client_ = std::make_shared(options_.gcs_options); } RAY_CHECK_OK(gcs_client_->Connect(io_service_)); + RegisterToGcs(); actor_manager_ = std::unique_ptr(new ActorManager(gcs_client_->Actors())); // Initialize profiler. - profiler_ = std::make_shared(worker_context_, node_ip_address, - io_service_, gcs_client_); + profiler_ = std::make_shared( + worker_context_, options_.node_ip_address, io_service_, gcs_client_); // Initialize task receivers. - if (worker_type_ == WorkerType::WORKER || is_local_mode_) { - RAY_CHECK(task_execution_callback_ != nullptr); + if (options_.worker_type == WorkerType::WORKER || options_.is_local_mode) { + RAY_CHECK(options_.task_execution_callback != nullptr); auto execute_task = std::bind(&CoreWorker::ExecuteTask, this, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); @@ -154,17 +305,18 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // so that the worker (java/python .etc) can retrieve and handle the error // instead of crashing. 
auto grpc_client = rpc::NodeManagerWorkerClient::make( - node_ip_address, node_manager_port, *client_call_manager_); + options_.node_ip_address, options_.node_manager_port, *client_call_manager_); ClientID local_raylet_id; local_raylet_client_ = std::shared_ptr(new raylet::RayletClient( - io_service_, std::move(grpc_client), raylet_socket, worker_context_.GetWorkerID(), - (worker_type_ == ray::WorkerType::WORKER), worker_context_.GetCurrentJobID(), - language_, &local_raylet_id, core_worker_server_.GetPort())); + io_service_, std::move(grpc_client), options_.raylet_socket, GetWorkerID(), + (options_.worker_type == ray::WorkerType::WORKER), + worker_context_.GetCurrentJobID(), options_.language, &local_raylet_id, + core_worker_server_.GetPort())); connected_ = true; // Set our own address. RAY_CHECK(!local_raylet_id.IsNil()); - rpc_address_.set_ip_address(node_ip_address); + rpc_address_.set_ip_address(options_.node_ip_address); rpc_address_.set_port(core_worker_server_.GetPort()); rpc_address_.set_raylet_id(local_raylet_id.Binary()); rpc_address_.set_worker_id(worker_context_.GetWorkerID().Binary()); @@ -179,20 +331,21 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, new rpc::CoreWorkerClient(addr, *client_call_manager_)); }); - if (worker_type_ == ray::WorkerType::WORKER) { + if (options_.worker_type == ray::WorkerType::WORKER) { death_check_timer_.expires_from_now(boost::asio::chrono::milliseconds( RayConfig::instance().raylet_death_check_interval_milliseconds())); - death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this)); + death_check_timer_.async_wait( + boost::bind(&CoreWorker::CheckForRayletFailure, this, _1)); } internal_timer_.expires_from_now( boost::asio::chrono::milliseconds(kInternalHeartbeatMillis)); - internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this)); + internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1)); io_thread_ = 
std::thread(&CoreWorker::RunIOService, this); plasma_store_provider_.reset(new CoreWorkerPlasmaStoreProvider( - store_socket, local_raylet_client_, check_signals_, + options_.store_socket, local_raylet_client_, options_.check_signals, /*evict_if_full=*/RayConfig::instance().object_pinning_enabled(), boost::bind(&CoreWorker::TriggerGlobalGC, this), boost::bind(&CoreWorker::CurrentCallSite, this))); @@ -201,8 +354,8 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, RAY_LOG(DEBUG) << "Promoting object to plasma " << obj_id; RAY_CHECK_OK(Put(obj, /*contained_object_ids=*/{}, obj_id, /*pin_object=*/true)); }, - ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_, - check_signals_)); + options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_, + options_.check_signals)); task_manager_.reset(new TaskManager( memory_store_, reference_counter_, actor_manager_, @@ -222,16 +375,17 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // driver creates an object that is later evicted, we should notify the // user that we're unable to reconstruct the object, since we cannot // rerun the driver. 
- if (worker_type_ == WorkerType::DRIVER) { + if (options_.worker_type == WorkerType::DRIVER) { TaskSpecBuilder builder; const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID()); - builder.SetDriverTaskSpec(task_id, language_, worker_context_.GetCurrentJobID(), + builder.SetDriverTaskSpec(task_id, options_.language, + worker_context_.GetCurrentJobID(), TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), GetCallerId(), rpc_address_); std::shared_ptr data = std::make_shared(); data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); - if (!is_local_mode_) { + if (!options_.is_local_mode) { RAY_CHECK_OK(gcs_client_->Tasks().AsyncAdd(data, nullptr)); } SetCurrentTaskId(task_id); @@ -262,17 +416,9 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, } } -CoreWorker::~CoreWorker() { - io_service_.stop(); - io_thread_.join(); - if (log_dir_ != "") { - RayLog::ShutDownRayLog(); - } -} - void CoreWorker::Shutdown() { io_service_.stop(); - if (worker_type_ == WorkerType::WORKER) { + if (options_.worker_type == WorkerType::WORKER) { task_execution_service_.stop(); } } @@ -356,6 +502,17 @@ void CoreWorker::RunIOService() { io_service_.run(); } +void CoreWorker::WaitForShutdown() { + if (io_thread_.joinable()) { + io_thread_.join(); + } + if (options_.worker_type == WorkerType::WORKER) { + RAY_CHECK(task_execution_service_.stopped()); + } +} + +const WorkerID &CoreWorker::GetWorkerID() const { return worker_context_.GetWorkerID(); } + void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { worker_context_.SetCurrentTaskId(task_id); main_thread_task_id_ = task_id; @@ -375,8 +532,41 @@ void CoreWorker::SetCurrentTaskId(const TaskID &task_id) { } } -void CoreWorker::CheckForRayletFailure() { -// If the raylet fails, we will be reassigned to init (PID=1). 
+void CoreWorker::RegisterToGcs() { + std::unordered_map worker_info; + const auto &worker_id = GetWorkerID(); + worker_info.emplace("node_ip_address", options_.node_ip_address); + worker_info.emplace("plasma_store_socket", options_.store_socket); + worker_info.emplace("raylet_socket", options_.raylet_socket); + + if (options_.worker_type == WorkerType::DRIVER) { + auto start_time = std::chrono::duration_cast( + std::chrono::system_clock::now().time_since_epoch()) + .count(); + worker_info.emplace("driver_id", worker_id.Binary()); + worker_info.emplace("start_time", std::to_string(start_time)); + if (!options_.driver_name.empty()) { + worker_info.emplace("name", options_.driver_name); + } + } + + if (!options_.stdout_file.empty()) { + worker_info.emplace("stdout_file", options_.stdout_file); + } + if (!options_.stderr_file.empty()) { + worker_info.emplace("stderr_file", options_.stderr_file); + } + + RAY_CHECK_OK(gcs_client_->Workers().AsyncRegisterWorker(options_.worker_type, worker_id, + worker_info, nullptr)); +} + +void CoreWorker::CheckForRayletFailure(const boost::system::error_code &error) { + if (error == boost::asio::error::operation_aborted) { + return; + } + + // If the raylet fails, we will be reassigned to init (PID=1). if (getppid() == 1) { RAY_LOG(ERROR) << "Raylet failed. 
Shutting down."; Shutdown(); @@ -387,10 +577,15 @@ void CoreWorker::CheckForRayletFailure() { death_check_timer_.expiry() + boost::asio::chrono::milliseconds( RayConfig::instance().raylet_death_check_interval_milliseconds())); - death_check_timer_.async_wait(boost::bind(&CoreWorker::CheckForRayletFailure, this)); + death_check_timer_.async_wait( + boost::bind(&CoreWorker::CheckForRayletFailure, this, _1)); } -void CoreWorker::InternalHeartbeat() { +void CoreWorker::InternalHeartbeat(const boost::system::error_code &error) { + if (error == boost::asio::error::operation_aborted) { + return; + } + absl::MutexLock lock(&mutex_); while (!to_resubmit_.empty() && current_time_ms() > to_resubmit_.front().first) { RAY_CHECK_OK(direct_task_submitter_->SubmitTask(to_resubmit_.front().second)); @@ -398,7 +593,7 @@ void CoreWorker::InternalHeartbeat() { } internal_timer_.expires_at(internal_timer_.expiry() + boost::asio::chrono::milliseconds(kInternalHeartbeatMillis)); - internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this)); + internal_timer_.async_wait(boost::bind(&CoreWorker::InternalHeartbeat, this, _1)); } std::unordered_map> @@ -445,7 +640,7 @@ void CoreWorker::RegisterOwnershipInfoAndResolveFuture( reference_counter_->AddBorrowedObject(object_id, outer_object_id, owner_id, owner_address); - RAY_CHECK(!owner_id.IsNil() || is_local_mode_); + RAY_CHECK(!owner_id.IsNil() || options_.is_local_mode); // We will ask the owner about the object until the object is // created or we can no longer reach the owner. 
future_resolver_->ResolveFutureAsync(object_id, owner_id, owner_address); @@ -471,7 +666,7 @@ Status CoreWorker::Put(const RayObject &object, const std::vector &contained_object_ids, const ObjectID &object_id, bool pin_object) { bool object_exists; - if (is_local_mode_) { + if (options_.is_local_mode) { RAY_CHECK(memory_store_->Put(object, object_id)); return Status::OK(); } @@ -505,7 +700,7 @@ Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t worker_context_.GetNextPutIndex(), static_cast(TaskTransportType::DIRECT)); - if (is_local_mode_) { + if (options_.is_local_mode) { *data = std::make_shared(data_size); } else { RAY_RETURN_NOT_OK( @@ -523,7 +718,7 @@ Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t Status CoreWorker::Create(const std::shared_ptr &metadata, const size_t data_size, const ObjectID &object_id, std::shared_ptr *data) { - if (is_local_mode_) { + if (options_.is_local_mode) { return Status::NotImplemented( "Creating an object with a pre-existing ObjectID is not supported in local mode"); } else { @@ -791,7 +986,7 @@ TaskID CoreWorker::GetCallerId() const { Status CoreWorker::PushError(const JobID &job_id, const std::string &type, const std::string &error_message, double timestamp) { - if (is_local_mode_) { + if (options_.is_local_mode) { RAY_LOG(ERROR) << "Pushed Error with JobID: " << job_id << " of type: " << type << " with message: " << error_message << " at time: " << timestamp; return Status::OK(); @@ -831,7 +1026,7 @@ Status CoreWorker::SubmitTask(const RayFunction &function, rpc_address_, function, args, task_options.num_returns, task_options.resources, required_resources, return_ids); TaskSpecification task_spec = builder.Build(); - if (is_local_mode_) { + if (options_.is_local_mode) { return ExecuteTaskLocalMode(task_spec); } else { task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, @@ -866,7 +1061,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, 
*return_actor_id = actor_id; TaskSpecification task_spec = builder.Build(); Status status; - if (is_local_mode_) { + if (options_.is_local_mode) { status = ExecuteTaskLocalMode(task_spec); } else { task_manager_->AddPendingTask( @@ -914,7 +1109,7 @@ Status CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &f // Submit task. Status status; TaskSpecification task_spec = builder.Build(); - if (is_local_mode_) { + if (options_.is_local_mode) { return ExecuteTaskLocalMode(task_spec, actor_id); } task_manager_->AddPendingTask(GetCallerId(), rpc_address_, task_spec, @@ -1055,7 +1250,7 @@ std::unique_ptr CoreWorker::CreateProfileEvent( new worker::ProfileEvent(profiler_, event_type)); } -void CoreWorker::StartExecutingTasks() { task_execution_service_.run(); } +void CoreWorker::RunTaskExecutionLoop() { task_execution_service_.run(); } Status CoreWorker::AllocateReturnObjects( const std::vector &object_ids, const std::vector &data_sizes, @@ -1066,7 +1261,7 @@ Status CoreWorker::AllocateReturnObjects( RAY_CHECK(object_ids.size() == data_sizes.size()); return_objects->resize(object_ids.size(), nullptr); - rpc::Address owner_address(is_local_mode_ + rpc::Address owner_address(options_.is_local_mode ? rpc::Address() : worker_context_.GetCurrentTask()->CallerAddress()); @@ -1083,8 +1278,9 @@ Status CoreWorker::AllocateReturnObjects( } // Allocate a buffer for the return object. 
- if (is_local_mode_ || static_cast(data_sizes[i]) < - RayConfig::instance().max_direct_call_object_size()) { + if (options_.is_local_mode || + static_cast(data_sizes[i]) < + RayConfig::instance().max_direct_call_object_size()) { data_buffer = std::make_shared(data_sizes[i]); } else { RAY_RETURN_NOT_OK( @@ -1113,7 +1309,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, resource_ids_ = resource_ids; } - if (!is_local_mode_) { + if (!options_.is_local_mode) { worker_context_.SetCurrentTask(task_spec); SetCurrentTaskId(task_spec.TaskId()); } @@ -1162,13 +1358,17 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, SetCallerCreationTimestamp(); } - status = task_execution_callback_( + // Because we support concurrent actor calls, we need to update the + // worker ID for the current thread. + CoreWorkerProcess::SetCurrentThreadWorkerId(GetWorkerID()); + + status = options_.task_execution_callback( task_type, func, task_spec.GetRequiredResources().GetResourceMap(), args, - arg_reference_ids, return_ids, return_objects, worker_context_.GetWorkerID()); + arg_reference_ids, return_ids, return_objects); absl::optional caller_address( - is_local_mode_ ? absl::optional() - : worker_context_.GetCurrentTask()->CallerAddress()); + options_.is_local_mode ? absl::optional() + : worker_context_.GetCurrentTask()->CallerAddress()); for (size_t i = 0; i < return_objects->size(); i++) { // The object is nullptr if it already existed in the object store. 
if (!return_objects->at(i)) { @@ -1196,7 +1396,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, RAY_LOG(DEBUG) << "Decrementing ref for borrowed ID " << borrowed_id; reference_counter_->RemoveLocalReference(borrowed_id, &deleted); } - if (ref_counting_enabled_) { + if (options_.ref_counting_enabled) { memory_store_->Delete(deleted); } @@ -1210,7 +1410,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, "reference counting, and may cause problems in the object store."; } - if (!is_local_mode_) { + if (!options_.is_local_mode) { SetCurrentTaskId(TaskID::Nil()); worker_context_.ResetCurrentTask(task_spec); } @@ -1263,7 +1463,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, // Direct call type objects that weren't inlined have been promoted to plasma. // We need to put an OBJECT_IN_PLASMA error here so the subsequent call to Get() // properly redirects to the plasma store. - if (task.ArgId(i, 0).IsDirectCallType() && !is_local_mode_) { + if (task.ArgId(i, 0).IsDirectCallType() && !options_.is_local_mode) { RAY_UNUSED(memory_store_->Put(RayObject(rpc::ErrorType::OBJECT_IN_PLASMA), task.ArgId(i, 0))); } @@ -1309,7 +1509,7 @@ Status CoreWorker::BuildArgsForExecutor(const TaskSpecification &task, // Fetch by-reference arguments directly from the plasma store. bool got_exception = false; absl::flat_hash_map> result_map; - if (is_local_mode_) { + if (options_.is_local_mode) { RAY_RETURN_NOT_OK( memory_store_->Get(by_ref_ids, -1, worker_context_, &result_map, &got_exception)); } else { @@ -1472,7 +1672,16 @@ void CoreWorker::HandleKillActor(const rpc::KillActorRequest &request, if (request.no_reconstruction()) { RAY_IGNORE_EXPR(local_raylet_client_->Disconnect()); } - if (log_dir_ != "") { + if (options_.num_workers > 1) { + // TODO (kfstorm): Should we add some kind of check before sending the killing + // request? 
+ RAY_LOG(ERROR) + << "Killing an actor which is running in a worker process with multiple " + "workers will also kill other actors in this process. To avoid this, " + "please create the Java actor with some dynamic options to make it being " + "hosted in a dedicated worker process."; + } + if (options_.log_dir != "") { RayLog::ShutDownRayLog(); } exit(1); @@ -1522,8 +1731,8 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest & void CoreWorker::HandleLocalGC(const rpc::LocalGCRequest &request, rpc::LocalGCReply *reply, rpc::SendReplyCallback send_reply_callback) { - if (gc_collect_ != nullptr) { - gc_collect_(); + if (options_.gc_collect != nullptr) { + options_.gc_collect(); send_reply_callback(Status::OK(), nullptr, nullptr); } else { send_reply_callback(Status::NotImplemented("GC callback not defined"), nullptr, @@ -1573,7 +1782,7 @@ void CoreWorker::HandlePlasmaObjectReady(const rpc::PlasmaObjectReadyRequest &re void CoreWorker::SetActorId(const ActorID &actor_id) { absl::MutexLock lock(&mutex_); - if (!is_local_mode_) { + if (!options_.is_local_mode) { RAY_CHECK(actor_id_.IsNil()); } actor_id_ = actor_id; diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 61d286060..21cfe5f3f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -50,10 +50,9 @@ namespace ray { -/// The root class that contains all the core and language-independent functionalities -/// of the worker. This class is supposed to be used to implement app-language (Java, -/// Python, etc) workers. -class CoreWorker : public rpc::CoreWorkerServiceHandler { +class CoreWorker; + +struct CoreWorkerOptions { // Callback that must be implemented and provided by the language-specific worker // frontend to execute tasks and return their results. 
using TaskExecutionCallback = std::function> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, - std::vector> *results, const ray::WorkerID &worker_id)>; + std::vector> *results)>; + /// Type of this worker (i.e., DRIVER or WORKER). + WorkerType worker_type; + /// Application language of this worker (i.e., PYTHON or JAVA). + Language language; + /// Object store socket to connect to. + std::string store_socket; + /// Raylet socket to connect to. + std::string raylet_socket; + /// Job ID of this worker. + JobID job_id; + /// Options for the GCS client. + gcs::GcsClientOptions gcs_options; + /// Directory to write logs to. If this is empty, logs won't be written to a file. + std::string log_dir; + /// If false, will not call `RayLog::InstallFailureSignalHandler()`. + bool install_failure_signal_handler; + /// IP address of the node. + std::string node_ip_address; + /// Port of the local raylet. + int node_manager_port; + /// The name of the driver. + std::string driver_name; + /// The stdout file of this process. + std::string stdout_file; + /// The stderr file of this process. + std::string stderr_file; + /// Language worker callback to execute tasks. + TaskExecutionCallback task_execution_callback; + /// Application-language callback to check for signals that have been received + /// since calling into C++. This will be called periodically (at least every + /// 1s) during long-running operations. If the function returns anything but StatusOK, + /// any long-running operations in the core worker will short circuit and return that + /// status. + std::function check_signals; + /// Application-language callback to trigger garbage collection in the language + /// runtime. This is required to free distributed references that may otherwise + /// be held up in garbage objects. + std::function gc_collect; + /// Language worker callback to get the current call stack. + std::function get_lang_stack; + /// Whether to enable object ref counting. 
+ bool ref_counting_enabled; + /// Is local mode being used. + bool is_local_mode; + /// The number of workers to be started in the current process. + int num_workers; +}; + +/// Lifecycle management of one or more `CoreWorker` instances in a process. +/// +/// To start a driver in the current process: +/// CoreWorkerOptions options = { +/// WorkerType::DRIVER, // worker_type +/// ..., // other arguments +/// 1, // num_workers +/// }; +/// CoreWorkerProcess::Initialize(options); +/// +/// To shutdown a driver in the current process: +/// CoreWorkerProcess::Shutdown(); +/// +/// To start one or more workers in the current process: +/// CoreWorkerOptions options = { +/// WorkerType::WORKER, // worker_type +/// ..., // other arguments +/// num_workers, // num_workers +/// }; +/// CoreWorkerProcess::Initialize(options); +/// ... // Do other stuff +/// CoreWorkerProcess::RunTaskExecutionLoop(); +/// +/// To shutdown a worker in the current process, return a system exit status (with status +/// code `IntentionalSystemExit` or `UnexpectedSystemExit`) in the task execution +/// callback. +/// +/// If more than 1 worker is started, only the threads which invoke the +/// `task_execution_callback` will be automatically associated with the corresponding +/// worker. If you started your own threads and you want to use core worker APIs in these +/// threads, remember to call `CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id)` +/// once in the new thread before calling core worker APIs, to associate the current +/// thread with a worker. You can obtain the worker ID via +/// `CoreWorkerProcess::GetCoreWorker()->GetWorkerID()`. Currently a Java worker process +/// starts multiple workers by default, but can be configured to start only 1 worker by +/// overwriting the internal config `num_workers_per_process_java`. 
+/// +/// If only 1 worker is started (either because the worker type is driver, or the +/// `num_workers` in `CoreWorkerOptions` is set to 1), all threads will be automatically +/// associated to the only worker. Then no need to call `SetCurrentThreadWorkerId` in +/// your own threads. Currently a Python worker process starts only 1 worker. +class CoreWorkerProcess { + public: + /// + /// Public methods used in both DRIVER and WORKER mode. + /// + + /// Initialize core workers at the process level. + /// + /// \param[in] options The various initialization options. + static void Initialize(const CoreWorkerOptions &options); + + /// Get the core worker associated with the current thread. + /// NOTE (kfstorm): Here we return a reference instead of a `shared_ptr` to make sure + /// `CoreWorkerProcess` has full control of the destruction timing of `CoreWorker`. + static CoreWorker &GetCoreWorker(); + + /// Set the core worker associated with the current thread by worker ID. + /// Currently used by Java worker only. + /// + /// \param worker_id The worker ID of the core worker instance. + static void SetCurrentThreadWorkerId(const WorkerID &worker_id); + + /// Whether the current process has been initialized for core worker. + static bool IsInitialized(); + + /// + /// Public methods used in DRIVER mode only. + /// + + /// Shutdown the driver completely at the process level. + static void Shutdown(); + + /// + /// Public methods used in WORKER mode only. + /// + + /// Start receiving and executing tasks. + static void RunTaskExecutionLoop(); + + // The destructor is not to be used as a public API, but it's required by smart + // pointers. + ~CoreWorkerProcess(); + + private: + /// Create a `CoreWorkerProcess` with proper options. + /// + /// \param[in] options The various initialization options. + CoreWorkerProcess(const CoreWorkerOptions &options); + + /// Check that the core worker environment is initialized for this process. + /// + /// \return Void.
+ static void EnsureInitialized(); + + /// Get the `CoreWorker` instance by worker ID. + /// + /// \param[in] worker_id The worker ID. + /// \return The `CoreWorker` instance. + std::shared_ptr GetWorker(const WorkerID &worker_id) const + LOCKS_EXCLUDED(worker_map_mutex_); + + /// Create a new `CoreWorker` instance. + /// + /// \return The newly created `CoreWorker` instance. + std::shared_ptr CreateWorker() LOCKS_EXCLUDED(worker_map_mutex_); + + /// Remove an existing `CoreWorker` instance. + /// + /// \param[in] worker The existing `CoreWorker` instance. + /// \return Void. + void RemoveWorker(std::shared_ptr worker) LOCKS_EXCLUDED(worker_map_mutex_); + + /// The global instance of `CoreWorkerProcess`. + static std::unique_ptr instance_; + + /// The various options. + const CoreWorkerOptions options_; + + /// The core worker instance associated with the current thread. + /// Use weak_ptr here to avoid memory leak due to multi-threading. + static thread_local std::weak_ptr current_core_worker_; + + /// The only core worker instance, if the number of workers is 1. + std::shared_ptr global_worker_; + + /// The worker ID of the global worker, if the number of workers is 1. + const WorkerID global_worker_id_; + + /// Map from worker ID to worker. + std::unordered_map> workers_ + GUARDED_BY(worker_map_mutex_); + + /// To protect accessing the `workers_` map. + mutable absl::Mutex worker_map_mutex_; +}; + +/// The root class that contains all the core and language-independent functionalities +/// of the worker. This class is supposed to be used to implement app-language (Java, +/// Python, etc) workers. +class CoreWorker : public rpc::CoreWorkerServiceHandler { public: /// Construct a CoreWorker instance. /// - /// \param[in] worker_type Type of this worker. - /// \param[in] language Language of this worker. - /// \param[in] store_socket Object store socket to connect to. - /// \param[in] raylet_socket Raylet socket to connect to. - /// \param[in] job_id Job ID of this worker.
- /// \param[in] gcs_options Options for the GCS client. - /// \param[in] log_dir Directory to write logs to. If this is empty, logs - /// won't be written to a file. - /// \param[in] node_ip_address IP address of the node. - /// \param[in] node_manager_port Port of the local raylet. - /// \param[in] task_execution_callback Language worker callback to execute tasks. - /// \param[in] check_signals Language worker function to check for signals and handle - /// them. If the function returns anything but StatusOK, any long-running - /// operations in the core worker will short circuit and return that status. - /// \param[in] ref_counting_enabled Whether to enable object ref counting. + /// \param[in] options The various initialization options. + /// \param[in] worker_id ID of this worker. + CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_id); + + CoreWorker(CoreWorker const &) = delete; + void operator=(CoreWorker const &other) = delete; + + /// + /// Public methods used by `CoreWorkerProcess` and `CoreWorker` itself. /// - /// NOTE(zhijunfu): the constructor would throw if a failure happens. - CoreWorker(const WorkerType worker_type, const Language language, - const std::string &store_socket, const std::string &raylet_socket, - const JobID &job_id, const gcs::GcsClientOptions &gcs_options, - const std::string &log_dir, const std::string &node_ip_address, - int node_manager_port, const TaskExecutionCallback &task_execution_callback, - std::function check_signals = nullptr, - std::function gc_collect = nullptr, - std::function get_lang_stack = nullptr, - bool ref_counting_enabled = false, bool local_mode = false); - - virtual ~CoreWorker(); - - void Exit(bool intentional); + /// Gracefully disconnect the worker from other components of ray. e.g. Raylet. + /// If this function is called during shutdown, Raylet will treat it as an intentional + /// disconnect. + /// + /// \return Void. 
void Disconnect(); - WorkerType GetWorkerType() const { return worker_type_; } + /// Shut down the worker completely. + /// + /// \return void. + void Shutdown(); - Language GetLanguage() const { return language_; } + /// Block the current thread until the worker is shut down. + void WaitForShutdown(); + + /// Start receiving and executing tasks. + /// \return void. + void RunTaskExecutionLoop(); + + const WorkerID &GetWorkerID() const; + + WorkerType GetWorkerType() const { return options_.worker_type; } + + Language GetLanguage() const { return options_.language; } WorkerContext &GetWorkerContext() { return worker_context_; } - raylet::RayletClient &GetRayletClient() { return *local_raylet_client_; } - const TaskID &GetCurrentTaskId() const { return worker_context_.GetCurrentTaskID(); } - void SetCurrentTaskId(const TaskID &task_id); - const JobID &GetCurrentJobId() const { return worker_context_.GetCurrentJobID(); } - void SetActorId(const ActorID &actor_id); - void SetWebuiDisplay(const std::string &key, const std::string &message); void SetActorTitle(const std::string &title); @@ -139,7 +320,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { std::vector deleted; reference_counter_->RemoveLocalReference(object_id, &deleted); // TOOD(ilr): better way of keeping an object from being deleted - if (ref_counting_enabled_ && !is_local_mode_) { + if (options_.ref_counting_enabled && !options_.is_local_mode) { memory_store_->Delete(deleted); } } @@ -448,10 +629,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Create a profile event with a reference to the core worker's profiler. std::unique_ptr CreateProfileEvent(const std::string &event_type); - /// Start receiving and executing tasks. - /// \return void. - void StartExecutingTasks(); - + public: /// Allocate the return objects for an executing task. The caller should write into the /// data buffers of the allocated buffers. 
/// @@ -566,18 +744,25 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { void SubscribeToPlasmaAdd(const ObjectID &object_id); private: + void SetCurrentTaskId(const TaskID &task_id); + + void SetActorId(const ActorID &actor_id); + /// Run the io_service_ event loop. This should be called in a background thread. void RunIOService(); - /// Shut down the worker completely. - /// \return void. - void Shutdown(); + /// (WORKER mode only) Exit the worker. This is the entrypoint used to shutdown a + /// worker. + void Exit(bool intentional); + + /// Register this worker or driver to GCS. + void RegisterToGcs(); /// Check if the raylet has failed. If so, shutdown. - void CheckForRayletFailure(); + void CheckForRayletFailure(const boost::system::error_code &error); /// Heartbeat for internal bookkeeping. - void InternalHeartbeat(); + void InternalHeartbeat(const boost::system::error_code &error); /// /// Private methods related to task submission. @@ -682,30 +867,7 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { } } - /// Type of this worker (i.e., DRIVER or WORKER). - const WorkerType worker_type_; - - /// Application language of this worker (i.e., PYTHON or JAVA). - const Language language_; - - /// Directory where log files are written. - const std::string log_dir_; - - /// Whether local reference counting is enabled. - const bool ref_counting_enabled_; - - /// Is local mode being used. - const bool is_local_mode_; - - /// Application-language callback to check for signals that have been received - /// since calling into C++. This will be called periodically (at least every - /// 1s) during long-running operations. - std::function check_signals_; - - /// Application-language callback to trigger garbage collection in the language - /// runtime. This is required to free distributed references that may otherwise - /// be held up in garbage objects. 
- std::function gc_collect_; + const CoreWorkerOptions options_; /// Callback to get the current language (e.g., Python) call site. std::function get_call_site_; @@ -843,9 +1005,6 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// Profiler including a background thread that pushes profiling events to the GCS. std::shared_ptr profiler_; - /// Task execution callback. - TaskExecutionCallback task_execution_callback_; - /// A map from resource name to the resource IDs that are currently reserved /// for this worker. Each pair consists of the resource ID and the fraction /// of that resource allocated for this worker. This is set on task assignment. diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index 38b2765cf..57de4fdcf 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -82,7 +82,6 @@ jfieldID java_native_ray_object_metadata; jclass java_task_executor_class; jmethodID java_task_executor_execute; -jmethodID java_task_executor_get; JavaVM *jvm; @@ -197,9 +196,6 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { env->GetMethodID(java_task_executor_class, "execute", "(Ljava/util/List;Ljava/util/List;)Ljava/util/List;"); - java_task_executor_get = env->GetStaticMethodID( - java_task_executor_class, "get", "([B)Lorg/ray/runtime/task/TaskExecutor;"); - return CURRENT_JNI_VERSION; } diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index d3a318b74..8a0f7465f 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -141,9 +141,6 @@ extern jclass java_task_executor_class; /// execute method of TaskExecutor class extern jmethodID java_task_executor_execute; -/// The `get` method in TaskExecutor class -extern jmethodID java_task_executor_get; - #define CURRENT_JNI_VERSION JNI_VERSION_1_8 extern JavaVM *jvm; diff --git 
a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index 8679049b3..e4ec7ab90 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -20,6 +20,7 @@ #include "ray/core_worker/lib/java/jni_utils.h" thread_local JNIEnv *local_env = nullptr; +jobject java_task_executor = nullptr; inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, jobject gcs_client_options) { @@ -36,15 +37,20 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env, extern "C" { #endif -JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker( - JNIEnv *env, jclass, jint workerMode, jstring storeSocket, jstring rayletSocket, - jstring nodeIpAddress, jint nodeManagerPort, jbyteArray jobId, - jobject gcsClientOptions) { - auto native_store_socket = JavaStringToNativeString(env, storeSocket); - auto native_raylet_socket = JavaStringToNativeString(env, rayletSocket); - auto job_id = JavaByteArrayToId(env, jobId); - auto gcs_client_options = ToGcsClientOptions(env, gcsClientOptions); - auto node_ip_address = JavaStringToNativeString(env, nodeIpAddress); +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize( + JNIEnv *env, jclass, jint workerMode, jstring nodeIpAddress, jint nodeManagerPort, + jstring driverName, jstring storeSocket, jstring rayletSocket, jbyteArray jobId, + jobject gcsClientOptions, jint numWorkersPerProcess, jstring logDir, + jobject rayletConfigParameters) { + auto raylet_config = JavaMapToNativeMap( + env, rayletConfigParameters, + [](JNIEnv *env, jobject java_key) { + return JavaStringToNativeString(env, (jstring)java_key); + }, + [](JNIEnv *env, jobject java_value) { + return JavaStringToNativeString(env, (jstring)java_value); + }); + RayConfig::instance().initialize(raylet_config); auto task_execution_callback = [](ray::TaskType 
task_type, const ray::RayFunction &ray_function, @@ -52,8 +58,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork const std::vector> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, - std::vector> *results, - const ray::WorkerID &worker_id) { + std::vector> *results) { JNIEnv *env = local_env; if (!env) { // Attach the native thread to JVM. @@ -64,12 +69,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork } RAY_CHECK(env); - - auto worker_id_bytes = IdToJavaByteArray(env, worker_id); - jobject local_java_task_executor = env->CallStaticObjectMethod( - java_task_executor_class, java_task_executor_get, worker_id_bytes); - - RAY_CHECK(local_java_task_executor); + RAY_CHECK(java_task_executor); // convert RayFunction jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList( env, ray_function.GetFunctionDescriptor()); @@ -80,7 +80,7 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork // invoke Java method jobject java_return_objects = - env->CallObjectMethod(local_java_task_executor, java_task_executor_execute, + env->CallObjectMethod(java_task_executor, java_task_executor_execute, ray_function_array_list, args_array_list); RAY_CHECK_JAVA_EXCEPTION(env); std::vector> return_objects; @@ -99,81 +99,70 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork return ray::Status::OK(); }; - try { - auto core_worker = new ray::CoreWorker( - static_cast(workerMode), ::Language::JAVA, native_store_socket, - native_raylet_socket, job_id, gcs_client_options, /*log_dir=*/"", node_ip_address, - nodeManagerPort, task_execution_callback); - return reinterpret_cast(core_worker); - } catch (const std::exception &e) { - std::ostringstream oss; - oss << "Failed to construct core worker: " << e.what(); - THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, ray::Status::Invalid(oss.str()), 0); - return 0; // To make compiler no 
complain - } + ray::CoreWorkerOptions options = { + static_cast(workerMode), // worker_type + ray::Language::JAVA, // language + JavaStringToNativeString(env, storeSocket), // store_socket + JavaStringToNativeString(env, rayletSocket), // raylet_socket + JavaByteArrayToId(env, jobId), // job_id + ToGcsClientOptions(env, gcsClientOptions), // gcs_options + JavaStringToNativeString(env, logDir), // log_dir + // TODO (kfstorm): JVM would crash if install_failure_signal_handler was set to true + false, // install_failure_signal_handler + JavaStringToNativeString(env, nodeIpAddress), // node_ip_address + static_cast(nodeManagerPort), // node_manager_port + JavaStringToNativeString(env, driverName), // driver_name + "", // stdout_file + "", // stderr_file + task_execution_callback, // task_execution_callback + nullptr, // check_signals + nullptr, // gc_collect + nullptr, // get_lang_stack + false, // ref_counting_enabled + false, // is_local_mode + static_cast(numWorkersPerProcess), // num_workers + }; + + ray::CoreWorkerProcess::Initialize(options); } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) { - local_env = env; - auto core_worker = reinterpret_cast(nativeCoreWorkerPointer); - core_worker->StartExecutingTasks(); - local_env = nullptr; + JNIEnv *env, jclass o, jobject javaTaskExecutor) { + java_task_executor = javaTaskExecutor; + ray::CoreWorkerProcess::RunTaskExecutionLoop(); + java_task_executor = nullptr; } -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer) { - auto core_worker = reinterpret_cast(nativeCoreWorkerPointer); - core_worker->Disconnect(); - delete core_worker; -} - -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup( - JNIEnv *env, jclass, jstring logDir, jobject rayletConfigParameters) { - std::string log_dir = JavaStringToNativeString(env,
logDir); - ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, log_dir); - // TODO (kfstorm): We can't InstallFailureSignalHandler here, because JVM already - // installed its own signal handler. It's possible to fix this by chaining signal - // handlers. But it's not easy. See - // https://docs.oracle.com/javase/9/troubleshoot/handle-signals-and-exceptions.htm. - auto raylet_config = JavaMapToNativeMap( - env, rayletConfigParameters, - [](JNIEnv *env, jobject java_key) { - return JavaStringToNativeString(env, (jstring)java_key); - }, - [](JNIEnv *env, jobject java_value) { - return JavaStringToNativeString(env, (jstring)java_value); - }); - RayConfig::instance().initialize(raylet_config); -} - -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, - jclass) { - ray::RayLog::ShutDownRayLog(); +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *env, + jclass o) { + ray::CoreWorkerProcess::Shutdown(); } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jstring resourceName, - jdouble capacity, jbyteArray nodeId) { + JNIEnv *env, jclass, jstring resourceName, jdouble capacity, jbyteArray nodeId) { const auto node_id = JavaByteArrayToId(env, nodeId); const char *native_resource_name = env->GetStringUTFChars(resourceName, JNI_FALSE); - auto status = - reinterpret_cast(nativeCoreWorkerPointer) - ->SetResource(native_resource_name, static_cast(capacity), node_id); + auto status = ray::CoreWorkerProcess::GetCoreWorker().SetResource( + native_resource_name, static_cast(capacity), node_id); env->ReleaseStringUTFChars(resourceName, native_resource_name); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray actorId, - jboolean noReconstruction) { - auto core_worker 
= reinterpret_cast(nativeCoreWorkerPointer); - auto status = core_worker->KillActor(JavaByteArrayToId(env, actorId), - /*force_kill=*/true, noReconstruction); + JNIEnv *env, jclass, jbyteArray actorId, jboolean noReconstruction) { + auto status = ray::CoreWorkerProcess::GetCoreWorker().KillActor( + JavaByteArrayToId(env, actorId), + /*force_kill=*/true, noReconstruction); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker( + JNIEnv *env, jclass, jbyteArray workerId) { + const auto worker_id = JavaByteArrayToId(env, workerId); + ray::CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h index e6dadede5..0e642dde4 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -23,61 +23,55 @@ extern "C" { #endif /* * Class: org_ray_runtime_RayNativeRuntime - * Method: nativeInitCoreWorker + * Method: nativeInitialize * Signature: - * (ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;I[BLorg/ray/runtime/gcs/GcsClientOptions;)J + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLorg/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V */ -JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWorker( - JNIEnv *, jclass, jint, jstring, jstring, jstring, jint, jbyteArray, jobject); +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitialize( + JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, + jint, jstring, jobject); /* * Class: org_ray_runtime_RayNativeRuntime * Method: nativeRunTaskExecutor - * Signature: (J)V + * Signature: (Lorg/ray/runtime/task/TaskExecutor;)V */ JNIEXPORT void JNICALL 
-Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jlong); +Java_org_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, jobject); /* * Class: org_ray_runtime_RayNativeRuntime - * Method: nativeDestroyCoreWorker - * Signature: (J)V - */ -JNIEXPORT void JNICALL -Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass, jlong); - -/* - * Class: org_ray_runtime_RayNativeRuntime - * Method: nativeSetup - * Signature: (Ljava/lang/String;Ljava/util/Map;)V - */ -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass, - jstring, - jobject); - -/* - * Class: org_ray_runtime_RayNativeRuntime - * Method: nativeShutdownHook + * Method: nativeShutdown * Signature: ()V */ -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, - jclass); +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *, + jclass); /* * Class: org_ray_runtime_RayNativeRuntime * Method: nativeSetResource - * Signature: (JLjava/lang/String;D[B)V + * Signature: (Ljava/lang/String;D[B)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetResource( - JNIEnv *, jclass, jlong, jstring, jdouble, jbyteArray); + JNIEnv *, jclass, jstring, jdouble, jbyteArray); /* * Class: org_ray_runtime_RayNativeRuntime * Method: nativeKillActor - * Signature: (J[BZ)V + * Signature: ([BZ)V */ -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor( - JNIEnv *, jclass, jlong, jbyteArray, jboolean); +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeKillActor(JNIEnv *, + jclass, + jbyteArray, + jboolean); + +/* + * Class: org_ray_runtime_RayNativeRuntime + * Method: nativeSetCoreWorker + * Signature: ([B)V + */ +JNIEXPORT void JNICALL +Java_org_ray_runtime_RayNativeRuntime_nativeSetCoreWorker(JNIEnv *, jclass, jbyteArray); #ifdef __cplusplus } diff --git 
a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index 9d631ebf3..dde1acbf1 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -15,47 +15,44 @@ #include "ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h" #include #include "ray/common/id.h" +#include "ray/core_worker/actor_handle.h" #include "ray/core_worker/common.h" #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" -inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) { - return *reinterpret_cast(nativeCoreWorkerPointer); -} - #ifdef __cplusplus extern "C" { #endif JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) { + JNIEnv *env, jclass o, jbyteArray actorId) { auto actor_id = JavaByteArrayToId(env, actorId); ray::ActorHandle *native_actor_handle = nullptr; - auto status = GetCoreWorker(nativeCoreWorkerPointer) - .GetActorHandle(actor_id, &native_actor_handle); + auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle( + actor_id, &native_actor_handle); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, false); return native_actor_handle->ActorLanguage(); } JNIEXPORT jobject JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) { + JNIEnv *env, jclass o, jbyteArray actorId) { auto actor_id = JavaByteArrayToId(env, actorId); ray::ActorHandle *native_actor_handle = nullptr; - auto status = GetCoreWorker(nativeCoreWorkerPointer) - .GetActorHandle(actor_id, &native_actor_handle); + auto status = ray::CoreWorkerProcess::GetCoreWorker().GetActorHandle( + actor_id, &native_actor_handle); 
THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); auto function_descriptor = native_actor_handle->ActorCreationTaskFunctionDescriptor(); return NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor); } JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray actorId) { + JNIEnv *env, jclass o, jbyteArray actorId) { auto actor_id = JavaByteArrayToId(env, actorId); std::string output; ObjectID actor_handle_id; - ray::Status status = GetCoreWorker(nativeCoreWorkerPointer) - .SerializeActorHandle(actor_id, &output, &actor_handle_id); + ray::Status status = ray::CoreWorkerProcess::GetCoreWorker().SerializeActorHandle( + actor_id, &output, &actor_handle_id); jbyteArray bytes = env->NewByteArray(output.size()); env->SetByteArrayRegion(bytes, 0, output.size(), reinterpret_cast(output.c_str())); @@ -63,13 +60,13 @@ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSer } JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize( - JNIEnv *env, jclass o, jlong nativeCoreWorkerPointer, jbyteArray data) { + JNIEnv *env, jclass o, jbyteArray data) { auto buffer = JavaByteArrayToNativeBuffer(env, data); RAY_CHECK(buffer->Size() > 0); auto binary = std::string(reinterpret_cast(buffer->Data()), buffer->Size()); auto actor_id = - GetCoreWorker(nativeCoreWorkerPointer) - .DeserializeAndRegisterActorHandle(binary, /*outer_object_id=*/ObjectID::Nil()); + ray::CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle( + binary, /*outer_object_id=*/ObjectID::Nil()); return IdToJavaByteArray(env, actor_id); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h index 26260b5a9..b04dc4fbe 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h +++ 
b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.h @@ -24,35 +24,35 @@ extern "C" { /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetLanguage - * Signature: (J[B)I + * Signature: ([B)I */ -JNIEXPORT jint JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage( - JNIEnv *, jclass, jlong, jbyteArray); +JNIEXPORT jint JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeGetLanguage(JNIEnv *, jclass, jbyteArray); /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeGetActorCreationTaskFunctionDescriptor - * Signature: (J[B)Ljava/util/List; + * Signature: ([B)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDescriptor( - JNIEnv *, jclass, jlong, jbyteArray); + JNIEnv *, jclass, jbyteArray); /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeSerialize - * Signature: (J[B)[B + * Signature: ([B)[B */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize( - JNIEnv *, jclass, jlong, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize(JNIEnv *, jclass, jbyteArray); /* * Class: org_ray_runtime_actor_NativeRayActor * Method: nativeDeserialize - * Signature: (J[B)[B + * Signature: ([B)[B */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize( - JNIEnv *, jclass, jlong, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_org_ray_runtime_actor_NativeRayActor_nativeDeserialize(JNIEnv *, jclass, jbyteArray); #ifdef __cplusplus } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc index 211f23ff7..780834d5e 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.cc @@ -19,51 +19,48 @@ #include 
"ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" -inline ray::WorkerContext &GetWorkerContextFromPointer(jlong nativeCoreWorkerPointer) { - return reinterpret_cast(nativeCoreWorkerPointer)->GetWorkerContext(); -} - #ifdef __cplusplus extern "C" { #endif JNIEXPORT jint JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { - auto task_spec = GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTask(); +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *env, + jclass) { + auto task_spec = + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTask(); RAY_CHECK(task_spec) << "Current task is not set."; return static_cast(task_spec->GetMessage().type()); } JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *env, + jclass) { const ray::TaskID &task_id = - GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentTaskID(); + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentTaskID(); return IdToJavaByteBuffer(env, task_id); } JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *env, + jclass) { const auto &job_id = - GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentJobID(); + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentJobID(); return IdToJavaByteBuffer(env, job_id); } JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { 
+Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *env, + jclass) { const auto &worker_id = - GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetWorkerID(); + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetWorkerID(); return IdToJavaByteBuffer(env, worker_id); } JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *env, + jclass) { const auto &actor_id = - GetWorkerContextFromPointer(nativeCoreWorkerPointer).GetCurrentActorID(); + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID(); return IdToJavaByteBuffer(env, actor_id); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h index db89e2703..983ff1dcb 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_context_NativeWorkerContext.h @@ -24,47 +24,45 @@ extern "C" { /* * Class: org_ray_runtime_context_NativeWorkerContext * Method: nativeGetCurrentTaskType - * Signature: (J)I + * Signature: ()I */ JNIEXPORT jint JNICALL Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskType(JNIEnv *, - jclass, jlong); + jclass); /* * Class: org_ray_runtime_context_NativeWorkerContext * Method: nativeGetCurrentTaskId - * Signature: (J)Ljava/nio/ByteBuffer; + * Signature: ()Ljava/nio/ByteBuffer; */ JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass, - jlong); +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentTaskId(JNIEnv *, jclass); /* * Class: org_ray_runtime_context_NativeWorkerContext * Method: nativeGetCurrentJobId - * Signature: (J)Ljava/nio/ByteBuffer; + * Signature: 
()Ljava/nio/ByteBuffer; */ JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass, - jlong); +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentJobId(JNIEnv *, jclass); /* * Class: org_ray_runtime_context_NativeWorkerContext * Method: nativeGetCurrentWorkerId - * Signature: (J)Ljava/nio/ByteBuffer; + * Signature: ()Ljava/nio/ByteBuffer; */ JNIEXPORT jobject JNICALL Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentWorkerId(JNIEnv *, - jclass, jlong); + jclass); /* * Class: org_ray_runtime_context_NativeWorkerContext * Method: nativeGetCurrentActorId - * Signature: (J)Ljava/nio/ByteBuffer; + * Signature: ()Ljava/nio/ByteBuffer; */ JNIEXPORT jobject JNICALL -Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, jclass, - jlong); +Java_org_ray_runtime_context_NativeWorkerContext_nativeGetCurrentActorId(JNIEnv *, + jclass); #ifdef __cplusplus } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc index ba4fdab5d..a4a175c0b 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.cc @@ -24,55 +24,51 @@ extern "C" { #endif JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject obj) { +Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *env, jclass, jobject obj) { auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); RAY_CHECK(ray_object != nullptr); ray::ObjectID object_id; - auto status = reinterpret_cast(nativeCoreWorkerPointer) - ->Put(*ray_object, {}, &object_id); + auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, 
{}, &object_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return IdToJavaByteArray(env, object_id); } JNIEXPORT void JNICALL -Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray objectId, - jobject obj) { +Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *env, jclass, jbyteArray objectId, jobject obj) { auto object_id = JavaByteArrayToId(env, objectId); auto ray_object = JavaNativeRayObjectToNativeRayObject(env, obj); RAY_CHECK(ray_object != nullptr); - auto status = reinterpret_cast(nativeCoreWorkerPointer) - ->Put(*ray_object, {}, object_id); + auto status = ray::CoreWorkerProcess::GetCoreWorker().Put(*ray_object, {}, object_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject ids, jlong timeoutMs) { + JNIEnv *env, jclass, jobject ids, jlong timeoutMs) { std::vector object_ids; JavaListToNativeVector( env, ids, &object_ids, [](JNIEnv *env, jobject id) { return JavaByteArrayToId(env, static_cast(id)); }); std::vector> results; - auto status = reinterpret_cast(nativeCoreWorkerPointer) - ->Get(object_ids, (int64_t)timeoutMs, &results); + auto status = ray::CoreWorkerProcess::GetCoreWorker().Get(object_ids, + (int64_t)timeoutMs, &results); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeVectorToJavaList>( env, results, NativeRayObjectToJavaNativeRayObject); } JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds, - jint numObjects, jlong timeoutMs) { + JNIEnv *env, jclass, jobject objectIds, jint numObjects, jlong timeoutMs) { std::vector object_ids; JavaListToNativeVector( env, objectIds, 
&object_ids, [](JNIEnv *env, jobject id) { return JavaByteArrayToId(env, static_cast(id)); }); std::vector results; - auto status = reinterpret_cast(nativeCoreWorkerPointer) - ->Wait(object_ids, (int)numObjects, (int64_t)timeoutMs, &results); + auto status = ray::CoreWorkerProcess::GetCoreWorker().Wait( + object_ids, (int)numObjects, (int64_t)timeoutMs, &results); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeVectorToJavaList(env, results, [](JNIEnv *env, const bool &item) { return env->NewObject(java_boolean_class, java_boolean_init, (jboolean)item); @@ -80,15 +76,15 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWa } JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jobject objectIds, - jboolean localOnly, jboolean deleteCreatingTasks) { + JNIEnv *env, jclass, jobject objectIds, jboolean localOnly, + jboolean deleteCreatingTasks) { std::vector object_ids; JavaListToNativeVector( env, objectIds, &object_ids, [](JNIEnv *env, jobject id) { return JavaByteArrayToId(env, static_cast(id)); }); - auto status = reinterpret_cast(nativeCoreWorkerPointer) - ->Delete(object_ids, (bool)localOnly, (bool)deleteCreatingTasks); + auto status = ray::CoreWorkerProcess::GetCoreWorker().Delete( + object_ids, (bool)localOnly, (bool)deleteCreatingTasks); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h index 08ee544c2..1cf64a7ba 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_object_NativeObjectStore.h @@ -24,44 +24,44 @@ extern "C" { /* * Class: org_ray_runtime_object_NativeObjectStore * Method: nativePut - * Signature: (JLorg/ray/runtime/object/NativeRayObject;)[B + * Signature: 
(Lorg/ray/runtime/object/NativeRayObject;)[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_object_NativeObjectStore_nativePut__JLorg_ray_runtime_object_NativeRayObject_2( - JNIEnv *, jclass, jlong, jobject); +Java_org_ray_runtime_object_NativeObjectStore_nativePut__Lorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *, jclass, jobject); /* * Class: org_ray_runtime_object_NativeObjectStore * Method: nativePut - * Signature: (J[BLorg/ray/runtime/object/NativeRayObject;)V + * Signature: ([BLorg/ray/runtime/object/NativeRayObject;)V */ JNIEXPORT void JNICALL -Java_org_ray_runtime_object_NativeObjectStore_nativePut__J_3BLorg_ray_runtime_object_NativeRayObject_2( - JNIEnv *, jclass, jlong, jbyteArray, jobject); +Java_org_ray_runtime_object_NativeObjectStore_nativePut___3BLorg_ray_runtime_object_NativeRayObject_2( + JNIEnv *, jclass, jbyteArray, jobject); /* * Class: org_ray_runtime_object_NativeObjectStore * Method: nativeGet - * Signature: (JLjava/util/List;J)Ljava/util/List; + * Signature: (Ljava/util/List;J)Ljava/util/List; */ -JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeGet( - JNIEnv *, jclass, jlong, jobject, jlong); +JNIEXPORT jobject JNICALL +Java_org_ray_runtime_object_NativeObjectStore_nativeGet(JNIEnv *, jclass, jobject, jlong); /* * Class: org_ray_runtime_object_NativeObjectStore * Method: nativeWait - * Signature: (JLjava/util/List;IJ)Ljava/util/List; + * Signature: (Ljava/util/List;IJ)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeWait( - JNIEnv *, jclass, jlong, jobject, jint, jlong); + JNIEnv *, jclass, jobject, jint, jlong); /* * Class: org_ray_runtime_object_NativeObjectStore * Method: nativeDelete - * Signature: (JLjava/util/List;ZZ)V + * Signature: (Ljava/util/List;ZZ)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_object_NativeObjectStore_nativeDelete( - JNIEnv *, jclass, jlong, jobject, jboolean, jboolean); + JNIEnv *, jclass, jobject, jboolean, jboolean); 
#ifdef __cplusplus } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc index 806036d3c..9753793ef 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.cc @@ -27,9 +27,9 @@ extern "C" { using ray::ClientID; JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer) { - auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *env, + jclass) { + auto &core_worker = ray::CoreWorkerProcess::GetCoreWorker(); const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); const auto &task_spec = core_worker.GetWorkerContext().GetCurrentTask(); RAY_CHECK(task_spec->IsActorTask()); @@ -44,11 +44,12 @@ Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint( JNIEXPORT void JNICALL Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *env, jclass, jlong nativeCoreWorkerPointer, jbyteArray checkpointId) { - auto &core_worker = *reinterpret_cast(nativeCoreWorkerPointer); - const auto &actor_id = core_worker.GetWorkerContext().GetCurrentActorID(); + JNIEnv *env, jclass, jbyteArray checkpointId) { + const auto &actor_id = + ray::CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID(); const auto checkpoint_id = JavaByteArrayToId(env, checkpointId); - auto status = core_worker.NotifyActorResumedFromCheckpoint(actor_id, checkpoint_id); + auto status = ray::CoreWorkerProcess::GetCoreWorker().NotifyActorResumedFromCheckpoint( + actor_id, checkpoint_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, (void)0); } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h 
b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h index 21cb99a84..87b0f1837 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskExecutor.h @@ -26,20 +26,19 @@ extern "C" { /* * Class: org_ray_runtime_task_NativeTaskExecutor * Method: nativePrepareCheckpoint - * Signature: (J)[B + * Signature: ()[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass, - jlong); +Java_org_ray_runtime_task_NativeTaskExecutor_nativePrepareCheckpoint(JNIEnv *, jclass); /* * Class: org_ray_runtime_task_NativeTaskExecutor * Method: nativeNotifyActorResumedFromCheckpoint - * Signature: (J[B)V + * Signature: ([B)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_task_NativeTaskExecutor_nativeNotifyActorResumedFromCheckpoint( - JNIEnv *, jclass, jlong, jbyteArray); + JNIEnv *, jclass, jbyteArray); #ifdef __cplusplus } diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index 7116b6684..9a1e26a75 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -19,10 +19,6 @@ #include "ray/core_worker/core_worker.h" #include "ray/core_worker/lib/java/jni_utils.h" -inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) { - return *reinterpret_cast(nativeCoreWorkerPointer); -} - inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { std::vector function_descriptor_list; jobject list = @@ -127,17 +123,17 @@ extern "C" { #endif JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( - JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor, - jobject args, jint numReturns, jobject callOptions) { + JNIEnv 
*env, jclass p, jobject functionDescriptor, jobject args, jint numReturns, + jobject callOptions) { auto ray_function = ToRayFunction(env, functionDescriptor); auto task_args = ToTaskArgs(env, args); auto task_options = ToTaskOptions(env, numReturns, callOptions); std::vector return_ids; // TODO (kfstorm): Allow setting `max_retries` via `CallOptions`. - auto status = GetCoreWorker(nativeCoreWorkerPointer) - .SubmitTask(ray_function, task_args, task_options, &return_ids, - /*max_retries=*/0); + auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitTask( + ray_function, task_args, task_options, &return_ids, + /*max_retries=*/0); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); @@ -146,16 +142,16 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSu JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( - JNIEnv *env, jclass p, jlong nativeCoreWorkerPointer, jobject functionDescriptor, - jobject args, jobject actorCreationOptions) { + JNIEnv *env, jclass p, jobject functionDescriptor, jobject args, + jobject actorCreationOptions) { auto ray_function = ToRayFunction(env, functionDescriptor); auto task_args = ToTaskArgs(env, args); auto actor_creation_options = ToActorCreationOptions(env, actorCreationOptions); - ray::ActorID actor_id; - auto status = GetCoreWorker(nativeCoreWorkerPointer) - .CreateActor(ray_function, task_args, actor_creation_options, - /*extension_data*/ "", &actor_id); + ActorID actor_id; + auto status = ray::CoreWorkerProcess::GetCoreWorker().CreateActor( + ray_function, task_args, actor_creation_options, + /*extension_data*/ "", &actor_id); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return IdToJavaByteArray(env, actor_id); @@ -163,17 +159,16 @@ Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor( JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( - JNIEnv *env, jclass p, jlong 
nativeCoreWorkerPointer, jbyteArray actorId, - jobject functionDescriptor, jobject args, jint numReturns, jobject callOptions) { + JNIEnv *env, jclass p, jbyteArray actorId, jobject functionDescriptor, jobject args, + jint numReturns, jobject callOptions) { auto actor_id = JavaByteArrayToId(env, actorId); auto ray_function = ToRayFunction(env, functionDescriptor); auto task_args = ToTaskArgs(env, args); auto task_options = ToTaskOptions(env, numReturns, callOptions); std::vector return_ids; - auto status = - GetCoreWorker(nativeCoreWorkerPointer) - .SubmitActorTask(actor_id, ray_function, task_args, task_options, &return_ids); + auto status = ray::CoreWorkerProcess::GetCoreWorker().SubmitActorTask( + actor_id, ray_function, task_args, task_options, &return_ids); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); return NativeIdVectorToJavaByteArrayList(env, return_ids); diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h index 13f7f8c3c..472f2aaf7 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.h @@ -25,33 +25,32 @@ extern "C" { * Class: org_ray_runtime_task_NativeTaskSubmitter * Method: nativeSubmitTask * Signature: - * (JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + * (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitTask( - JNIEnv *, jclass, jlong, jobject, jobject, jint, jobject); + JNIEnv *, jclass, jobject, jobject, jint, jobject); /* * Class: org_ray_runtime_task_NativeTaskSubmitter * Method: nativeCreateActor * Signature: - * 
(JLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B + * (Lorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;Lorg/ray/api/options/ActorCreationOptions;)[B */ JNIEXPORT jbyteArray JNICALL -Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jlong, - jobject, jobject, - jobject); +Java_org_ray_runtime_task_NativeTaskSubmitter_nativeCreateActor(JNIEnv *, jclass, jobject, + jobject, jobject); /* * Class: org_ray_runtime_task_NativeTaskSubmitter * Method: nativeSubmitActorTask * Signature: - * (J[BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; + * ([BLorg/ray/runtime/functionmanager/FunctionDescriptor;Ljava/util/List;ILorg/ray/api/options/CallOptions;)Ljava/util/List; */ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jclass, - jlong, jbyteArray, - jobject, jobject, - jint, jobject); + jbyteArray, jobject, + jobject, jint, + jobject); #ifdef __cplusplus } diff --git a/src/ray/core_worker/reference_count_test.cc b/src/ray/core_worker/reference_count_test.cc index 0e275444e..ef5da8ed5 100644 --- a/src/ray/core_worker/reference_count_test.cc +++ b/src/ray/core_worker/reference_count_test.cc @@ -340,7 +340,7 @@ TEST(MemoryStoreIntegrationTest, TestSimple) { RAY_CHECK(store.Put(buffer, id1)); ASSERT_EQ(store.Size(), 1); std::vector> results; - WorkerContext ctx(WorkerType::WORKER, JobID::Nil()); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); RAY_CHECK_OK(store.Get({id1}, /*num_objects*/ 1, /*timeout_ms*/ -1, ctx, /*remove_after_get*/ true, &results)); ASSERT_EQ(results.size(), 1); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index b5f79ef54..3aba27809 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ 
b/src/ray/core_worker/test/core_worker_test.cc @@ -57,8 +57,7 @@ static void flushall_redis(void) { redisFree(context); } -ActorID CreateActorHelper(CoreWorker &worker, - std::unordered_map &resources, +ActorID CreateActorHelper(std::unordered_map &resources, uint64_t max_reconstructions) { std::unique_ptr actor_handle; @@ -78,8 +77,8 @@ ActorID CreateActorHelper(CoreWorker &worker, // Create an actor. ActorID actor_id; - RAY_CHECK_OK( - worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id)); + RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor( + func, args, actor_options, /*extension_data*/ "", &actor_id)); return actor_id; } @@ -90,7 +89,8 @@ std::string MetadataToString(std::shared_ptr obj) { class CoreWorkerTest : public ::testing::Test { public: - CoreWorkerTest(int num_nodes) : gcs_options_("127.0.0.1", 6379, "") { + CoreWorkerTest(int num_nodes) + : num_nodes_(num_nodes), gcs_options_("127.0.0.1", 6379, "") { #ifdef _WIN32 RAY_CHECK(false) << "port system() calls to Windows before running this test"; #endif @@ -250,9 +250,39 @@ class CoreWorkerTest : public ::testing::Test { ASSERT_TRUE(system(("rm -f " + gcs_server_pid).c_str()) == 0); } - void SetUp() {} + void SetUp() { + if (num_nodes_ > 0) { + CoreWorkerOptions options = { + WorkerType::DRIVER, // worker_type + Language::PYTHON, // langauge + raylet_store_socket_names_[0], // store_socket + raylet_socket_names_[0], // raylet_socket + NextJobId(), // job_id + gcs_options_, // gcs_options + "", // log_dir + true, // install_failure_signal_handler + "127.0.0.1", // node_ip_address + node_manager_port, // node_manager_port + "core_worker_test", // driver_name + "", // stdout_file + "", // stderr_file + nullptr, // task_execution_callback + nullptr, // check_signals + nullptr, // gc_collect + nullptr, // get_lang_stack + true, // ref_counting_enabled + false, // is_local_mode + 1, // num_workers + }; + CoreWorkerProcess::Initialize(options); + } + } - void TearDown() {} + 
void TearDown() { + if (num_nodes_ > 0) { + CoreWorkerProcess::Shutdown(); + } + } // Test normal tasks. void TestNormalTask(std::unordered_map &resources); @@ -271,13 +301,14 @@ class CoreWorkerTest : public ::testing::Test { void TestActorReconstruction(std::unordered_map &resources); protected: - bool WaitForDirectCallActorState(CoreWorker &worker, const ActorID &actor_id, - bool wait_alive, int timeout_ms); + bool WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive, + int timeout_ms); // Get the pid for the worker process that runs the actor. - int GetActorPid(CoreWorker &worker, const ActorID &actor_id, + int GetActorPid(const ActorID &actor_id, std::unordered_map &resources); + int num_nodes_; std::vector raylet_socket_names_; std::vector raylet_store_socket_names_; std::string raylet_monitor_pid_; @@ -285,18 +316,19 @@ class CoreWorkerTest : public ::testing::Test { std::string gcs_server_pid_; }; -bool CoreWorkerTest::WaitForDirectCallActorState(CoreWorker &worker, - const ActorID &actor_id, bool wait_alive, +bool CoreWorkerTest::WaitForDirectCallActorState(const ActorID &actor_id, bool wait_alive, int timeout_ms) { - auto condition_func = [&worker, actor_id, wait_alive]() -> bool { - bool actor_alive = worker.direct_actor_submitter_->IsActorAlive(actor_id); + auto condition_func = [actor_id, wait_alive]() -> bool { + bool actor_alive = + CoreWorkerProcess::GetCoreWorker().direct_actor_submitter_->IsActorAlive( + actor_id); return wait_alive ? 
actor_alive : !actor_alive; }; return WaitForCondition(condition_func, timeout_ms); } -int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id, +int CoreWorkerTest::GetActorPid(const ActorID &actor_id, std::unordered_map &resources) { std::vector args; TaskOptions options{1, resources}; @@ -304,10 +336,11 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id, RayFunction func{Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( "GetWorkerPid", "", "", "")}; - RAY_CHECK_OK(worker.SubmitActorTask(actor_id, func, args, options, &return_ids)); + RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().SubmitActorTask(actor_id, func, args, + options, &return_ids)); std::vector> results; - RAY_CHECK_OK(worker.Get(return_ids, -1, &results)); + RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().Get(return_ids, -1, &results)); if (nullptr == results[0]->GetData()) { // If failed to get actor process pid, return -1 @@ -320,9 +353,7 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id, } void CoreWorkerTest::TestNormalTask(std::unordered_map &resources) { - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - node_manager_port, nullptr); + auto &driver = CoreWorkerProcess::GetCoreWorker(); // Test for tasks with by-value and by-ref args. 
{ @@ -364,11 +395,9 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res } void CoreWorkerTest::TestActorTask(std::unordered_map &resources) { - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - node_manager_port, nullptr); + auto &driver = CoreWorkerProcess::GetCoreWorker(); - auto actor_id = CreateActorHelper(driver, resources, 1000); + auto actor_id = CreateActorHelper(resources, 1000); // Test submitting some tasks with by-value args for that actor. { @@ -452,18 +481,16 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso void CoreWorkerTest::TestActorReconstruction( std::unordered_map &resources) { - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - node_manager_port, nullptr); + auto &driver = CoreWorkerProcess::GetCoreWorker(); // creating actor. - auto actor_id = CreateActorHelper(driver, resources, 1000); + auto actor_id = CreateActorHelper(resources, 1000); // Wait for actor alive event. - ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */)); + ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */)); RAY_LOG(INFO) << "actor has been created"; - auto pid = GetActorPid(driver, actor_id, resources); + auto pid = GetActorPid(actor_id, resources); RAY_CHECK(pid != -1); // Test submitting some tasks with by-value args for that actor. @@ -477,9 +504,8 @@ void CoreWorkerTest::TestActorReconstruction( ASSERT_EQ(system("pkill mock_worker"), 0); // Wait for actor restruction event, and then for alive event. 
- auto check_actor_restart_func = [this, pid, &driver, &actor_id, - &resources]() -> bool { - auto new_pid = GetActorPid(driver, actor_id, resources); + auto check_actor_restart_func = [this, pid, &actor_id, &resources]() -> bool { + auto new_pid = GetActorPid(actor_id, resources); return new_pid != -1 && new_pid != pid; }; ASSERT_TRUE(WaitForCondition(check_actor_restart_func, 30 * 1000 /* 30s */)); @@ -514,12 +540,10 @@ void CoreWorkerTest::TestActorReconstruction( void CoreWorkerTest::TestActorFailure( std::unordered_map &resources) { - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - node_manager_port, nullptr); + auto &driver = CoreWorkerProcess::GetCoreWorker(); // creating actor. - auto actor_id = CreateActorHelper(driver, resources, 0 /* not reconstructable */); + auto actor_id = CreateActorHelper(resources, 0 /* not reconstructable */); // Test submitting some tasks with by-value args for that actor. { @@ -666,16 +690,14 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { } TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromInt(1), gcs_options_, "", - "127.0.0.1", node_manager_port, nullptr); + auto &driver = CoreWorkerProcess::GetCoreWorker(); std::vector object_ids; // Create an actor. std::unordered_map resources; - auto actor_id = CreateActorHelper(driver, resources, + auto actor_id = CreateActorHelper(resources, /*max_reconstructions=*/0); // wait for actor creation finish. - ASSERT_TRUE(WaitForDirectCallActorState(driver, actor_id, true, 30 * 1000 /* 30s */)); + ASSERT_TRUE(WaitForDirectCallActorState(actor_id, true, 30 * 1000 /* 30s */)); // Test submitting some tasks with by-value args for that actor. 
int64_t start_ms = current_time_ms(); const int num_tasks = 100000; @@ -713,7 +735,7 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { TEST_F(ZeroNodeTest, TestWorkerContext) { auto job_id = NextJobId(); - WorkerContext context(WorkerType::WORKER, job_id); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), job_id); ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); ASSERT_EQ(context.GetNextTaskIndex(), 1); ASSERT_EQ(context.GetNextTaskIndex(), 2); @@ -779,7 +801,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { absl::flat_hash_set wait_results; ObjectID nonexistent_id = ObjectID::FromRandom().WithDirectTransportType(); - WorkerContext ctx(WorkerType::WORKER, JobID::Nil()); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); wait_ids.insert(nonexistent_id); RAY_CHECK_OK(provider.Wait(wait_ids, ids.size() + 1, 100, ctx, &wait_results)); ASSERT_EQ(wait_results.size(), ids.size()); @@ -880,10 +902,7 @@ TEST_F(SingleNodeTest, TestMemoryStoreProvider) { } TEST_F(SingleNodeTest, TestObjectInterface) { - CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, - raylet_store_socket_names_[0], raylet_socket_names_[0], - JobID::FromInt(1), gcs_options_, "", "127.0.0.1", - node_manager_port, nullptr); + auto &core_worker = CoreWorkerProcess::GetCoreWorker(); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; diff --git a/src/ray/core_worker/test/direct_actor_transport_test.cc b/src/ray/core_worker/test/direct_actor_transport_test.cc index c68d84d08..50031c74a 100644 --- a/src/ray/core_worker/test/direct_actor_transport_test.cc +++ b/src/ray/core_worker/test/direct_actor_transport_test.cc @@ -233,7 +233,7 @@ rpc::PushTaskRequest CreatePushTaskRequestHelper(ActorID actor_id, int64_t count class MockWorkerContext : public WorkerContext { public: MockWorkerContext(WorkerType worker_type, const JobID &job_id) - : WorkerContext(worker_type, job_id) { + : WorkerContext(worker_type, 
WorkerID::FromRandom(), job_id) { current_actor_is_direct_call_ = true; } }; diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 580ce99cd..8c28b87ed 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -33,13 +33,34 @@ namespace ray { class MockWorker { public: MockWorker(const std::string &store_socket, const std::string &raylet_socket, - int node_manager_port, const gcs::GcsClientOptions &gcs_options) - : worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, - JobID::FromInt(1), gcs_options, /*log_dir=*/"", - /*node_id_address=*/"127.0.0.1", node_manager_port, - std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)) {} + int node_manager_port, const gcs::GcsClientOptions &gcs_options) { + CoreWorkerOptions options = { + WorkerType::WORKER, // worker_type + Language::PYTHON, // language + store_socket, // store_socket + raylet_socket, // raylet_socket + JobID::FromInt(1), // job_id + gcs_options, // gcs_options + "", // log_dir + true, // install_failure_signal_handler + "127.0.0.1", // node_ip_address + node_manager_port, // node_manager_port + "", // driver_name + "", // stdout_file + "", // stderr_file + std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, + _7), // task_execution_callback + nullptr, // check_signals + nullptr, // gc_collect + nullptr, // get_lang_stack + true, // ref_counting_enabled + false, // is_local_mode + 1, // num_workers + }; + CoreWorkerProcess::Initialize(options); + } - void StartExecutingTasks() { worker_.StartExecutingTasks(); } + void RunTaskExecutionLoop() { CoreWorkerProcess::RunTaskExecutionLoop(); } private: Status ExecuteTask(TaskType task_type, const RayFunction &ray_function, @@ -112,7 +133,6 @@ class MockWorker { return Status::OK(); } - CoreWorker worker_; int64_t prev_seq_no_ = 0; }; @@ -126,6 +146,6 @@ int main(int argc, char **argv) { ray::gcs::GcsClientOptions 
gcs_options("127.0.0.1", 6379, ""); ray::MockWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options); - worker.StartExecutingTasks(); + worker.RunTaskExecutionLoop(); return 0; } diff --git a/src/ray/core_worker/test/scheduling_queue_test.cc b/src/ray/core_worker/test/scheduling_queue_test.cc index 48b882073..322bc79d2 100644 --- a/src/ray/core_worker/test/scheduling_queue_test.cc +++ b/src/ray/core_worker/test/scheduling_queue_test.cc @@ -37,7 +37,7 @@ class MockWaiter : public DependencyWaiter { TEST(SchedulingQueueTest, TestInOrder) { boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, waiter, context); int n_ok = 0; int n_rej = 0; @@ -58,7 +58,7 @@ TEST(SchedulingQueueTest, TestWaitForObjects) { ObjectID obj3 = ObjectID::FromRandom(); boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, waiter, context); int n_ok = 0; int n_rej = 0; @@ -84,7 +84,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { ObjectID obj1 = ObjectID::FromRandom(); boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, waiter, context); int n_ok = 0; int n_rej = 0; @@ -102,7 +102,7 @@ TEST(SchedulingQueueTest, TestWaitForObjectsNotSubjectToSeqTimeout) { TEST(SchedulingQueueTest, TestOutOfOrder) { boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, 
waiter, context); int n_ok = 0; int n_rej = 0; @@ -120,7 +120,7 @@ TEST(SchedulingQueueTest, TestOutOfOrder) { TEST(SchedulingQueueTest, TestSeqWaitTimeout) { boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, waiter, context); int n_ok = 0; int n_rej = 0; @@ -143,7 +143,7 @@ TEST(SchedulingQueueTest, TestSeqWaitTimeout) { TEST(SchedulingQueueTest, TestSkipAlreadyProcessedByClient) { boost::asio::io_service io_service; MockWaiter waiter; - WorkerContext context(WorkerType::WORKER, JobID::Nil()); + WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::Nil()); SchedulingQueue queue(io_service, waiter, context); int n_ok = 0; int n_rej = 0; diff --git a/src/ray/core_worker/test/task_manager_test.cc b/src/ray/core_worker/test/task_manager_test.cc index e9efc9ac7..74e98e8fb 100644 --- a/src/ray/core_worker/test/task_manager_test.cc +++ b/src/ray/core_worker/test/task_manager_test.cc @@ -79,7 +79,7 @@ TEST_F(TaskManagerTest, TestTaskSuccess) { ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); - WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); rpc::PushTaskReply reply; auto return_object = reply.add_return_objects(); @@ -119,7 +119,7 @@ TEST_F(TaskManagerTest, TestTaskFailure) { ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); - WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); auto error = rpc::ErrorType::WORKER_DIED; 
manager_.PendingTaskFailed(spec.TaskId(), error); @@ -155,7 +155,7 @@ TEST_F(TaskManagerTest, TestTaskRetry) { ASSERT_TRUE(manager_.IsTaskPending(spec.TaskId())); ASSERT_EQ(reference_counter_->NumObjectIDsInScope(), 3); auto return_id = spec.ReturnId(0, TaskTransportType::DIRECT); - WorkerContext ctx(WorkerType::WORKER, JobID::FromInt(0)); + WorkerContext ctx(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0)); auto error = rpc::ErrorType::WORKER_DIED; for (int i = 0; i < num_retries; i++) { diff --git a/src/ray/gcs/accessor.h b/src/ray/gcs/accessor.h index 5459d1707..3cdf30598 100644 --- a/src/ray/gcs/accessor.h +++ b/src/ray/gcs/accessor.h @@ -568,6 +568,17 @@ class WorkerInfoAccessor { const std::shared_ptr &data_ptr, const StatusCallback &callback) = 0; + /// Register a worker to GCS asynchronously. + /// + /// \param worker_type The type of the worker. + /// \param worker_id The ID of the worker. + /// \param worker_info The information of the worker. + /// \return Status. 
+ virtual Status AsyncRegisterWorker( + rpc::WorkerType worker_type, const WorkerID &worker_id, + const std::unordered_map &worker_info, + const StatusCallback &callback) = 0; + protected: WorkerInfoAccessor() = default; }; diff --git a/src/ray/gcs/gcs_client.h b/src/ray/gcs/gcs_client.h index a1281ec7c..e61ea035c 100644 --- a/src/ray/gcs/gcs_client.h +++ b/src/ray/gcs/gcs_client.h @@ -45,6 +45,8 @@ class GcsClientOptions { password_(password), is_test_client_(is_test_client) {} + GcsClientOptions() {} + // GCS server address std::string server_ip_; int server_port_; diff --git a/src/ray/gcs/gcs_client/service_based_accessor.cc b/src/ray/gcs/gcs_client/service_based_accessor.cc index ebf66fbcf..b3dada08b 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.cc +++ b/src/ray/gcs/gcs_client/service_based_accessor.cc @@ -875,5 +875,25 @@ Status ServiceBasedWorkerInfoAccessor::AsyncReportWorkerFailure( return Status::OK(); } +Status ServiceBasedWorkerInfoAccessor::AsyncRegisterWorker( + rpc::WorkerType worker_type, const WorkerID &worker_id, + const std::unordered_map &worker_info, + const StatusCallback &callback) { + RAY_LOG(DEBUG) << "Registering the worker. worker id = " << worker_id; + rpc::RegisterWorkerRequest request; + request.set_worker_type(worker_type); + request.set_worker_id(worker_id.Binary()); + request.mutable_worker_info()->insert(worker_info.begin(), worker_info.end()); + client_impl_->GetGcsRpcClient().RegisterWorker( + request, + [worker_id, callback](const Status &status, const rpc::RegisterWorkerReply &reply) { + if (callback) { + callback(status); + } + RAY_LOG(DEBUG) << "Finished registering worker. 
worker id = " << worker_id; + }); + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/gcs_client/service_based_accessor.h b/src/ray/gcs/gcs_client/service_based_accessor.h index c428b5054..714327949 100644 --- a/src/ray/gcs/gcs_client/service_based_accessor.h +++ b/src/ray/gcs/gcs_client/service_based_accessor.h @@ -329,6 +329,11 @@ class ServiceBasedWorkerInfoAccessor : public WorkerInfoAccessor { Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, const StatusCallback &callback) override; + Status AsyncRegisterWorker( + rpc::WorkerType worker_type, const WorkerID &worker_id, + const std::unordered_map &worker_info, + const StatusCallback &callback) override; + private: ServiceBasedGcsClient *client_impl_; diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.cc b/src/ray/gcs/gcs_server/worker_info_handler_impl.cc index 05bcb84c9..bfe148503 100644 --- a/src/ray/gcs/gcs_server/worker_info_handler_impl.cc +++ b/src/ray/gcs/gcs_server/worker_info_handler_impl.cc @@ -40,5 +40,27 @@ void DefaultWorkerInfoHandler::HandleReportWorkerFailure( RAY_LOG(DEBUG) << "Finished reporting worker failure, " << worker_address.DebugString(); } +void DefaultWorkerInfoHandler::HandleRegisterWorker( + const RegisterWorkerRequest &request, RegisterWorkerReply *reply, + SendReplyCallback send_reply_callback) { + auto worker_type = request.worker_type(); + auto worker_id = WorkerID::FromBinary(request.worker_id()); + auto worker_info = MapFromProtobuf(request.worker_info()); + + auto on_done = [worker_id, reply, send_reply_callback](Status status) { + if (!status.ok()) { + RAY_LOG(ERROR) << "Failed to register worker " << worker_id; + } + GCS_RPC_SEND_REPLY(send_reply_callback, reply, Status::OK()); + }; + + Status status = gcs_client_.Workers().AsyncRegisterWorker(worker_type, worker_id, + worker_info, on_done); + if (!status.ok()) { + on_done(status); + } + RAY_LOG(DEBUG) << "Finished registering worker " << worker_id; +} + } // 
namespace rpc } // namespace ray diff --git a/src/ray/gcs/gcs_server/worker_info_handler_impl.h b/src/ray/gcs/gcs_server/worker_info_handler_impl.h index 35b61ab4f..2929c4ecc 100644 --- a/src/ray/gcs/gcs_server/worker_info_handler_impl.h +++ b/src/ray/gcs/gcs_server/worker_info_handler_impl.h @@ -31,6 +31,10 @@ class DefaultWorkerInfoHandler : public rpc::WorkerInfoHandler { ReportWorkerFailureReply *reply, SendReplyCallback send_reply_callback) override; + void HandleRegisterWorker(const RegisterWorkerRequest &request, + RegisterWorkerReply *reply, + SendReplyCallback send_reply_callback) override; + private: gcs::RedisGcsClient &gcs_client_; }; diff --git a/src/ray/gcs/redis_accessor.cc b/src/ray/gcs/redis_accessor.cc index c9e9c993e..9a1ff4445 100644 --- a/src/ray/gcs/redis_accessor.cc +++ b/src/ray/gcs/redis_accessor.cc @@ -747,6 +747,30 @@ Status RedisWorkerInfoAccessor::AsyncReportWorkerFailure( return worker_failure_table.Add(JobID::Nil(), worker_id, data_ptr, on_done); } +Status RedisWorkerInfoAccessor::AsyncRegisterWorker( + rpc::WorkerType worker_type, const WorkerID &worker_id, + const std::unordered_map &worker_info, + const StatusCallback &callback) { + std::vector args; + args.emplace_back("HMSET"); + if (worker_type == rpc::WorkerType::DRIVER) { + args.emplace_back("Drivers:" + worker_id.Binary()); + } else { + args.emplace_back("Workers:" + worker_id.Binary()); + } + for (const auto &entry : worker_info) { + args.push_back(entry.first); + args.push_back(entry.second); + } + + auto status = client_impl_->primary_context()->RunArgvAsync(args); + if (callback) { + // TODO (kfstorm): Invoke the callback asynchronously. 
+ callback(status); + } + return status; +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_accessor.h b/src/ray/gcs/redis_accessor.h index e2ba41f36..50cdeabc1 100644 --- a/src/ray/gcs/redis_accessor.h +++ b/src/ray/gcs/redis_accessor.h @@ -397,6 +397,11 @@ class RedisWorkerInfoAccessor : public WorkerInfoAccessor { Status AsyncReportWorkerFailure(const std::shared_ptr &data_ptr, const StatusCallback &callback) override; + Status AsyncRegisterWorker( + rpc::WorkerType worker_type, const WorkerID &worker_id, + const std::unordered_map &worker_info, + const StatusCallback &callback) override; + private: RedisGcsClient *client_impl_{nullptr}; diff --git a/src/ray/protobuf/gcs_service.proto b/src/ray/protobuf/gcs_service.proto index afb27e6d6..8f3a5150b 100644 --- a/src/ray/protobuf/gcs_service.proto +++ b/src/ray/protobuf/gcs_service.proto @@ -16,6 +16,7 @@ syntax = "proto3"; package ray.rpc; +import "src/ray/protobuf/common.proto"; import "src/ray/protobuf/gcs.proto"; message GcsStatus { @@ -345,8 +346,23 @@ message ReportWorkerFailureReply { GcsStatus status = 1; } +message RegisterWorkerRequest { + /// The type of the worker. + WorkerType worker_type = 1; + /// The ID of the worker. + bytes worker_id = 2; + /// The information of the worker in a dictionary. + map worker_info = 3; +} + +message RegisterWorkerReply { + GcsStatus status = 1; +} + // Service for worker info access. service WorkerInfoGcsService { // Report a worker failure to GCS Service. rpc ReportWorkerFailure(ReportWorkerFailureRequest) returns (ReportWorkerFailureReply); + // Register a worker to GCS Service. 
+ rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply); } diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 04ea45ee8..df9c6bf98 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -182,6 +182,10 @@ class GcsRpcClient { VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, ReportWorkerFailure, worker_info_grpc_client_, ) + /// Register a worker to GCS Service. + VOID_GCS_RPC_CLIENT_METHOD(WorkerInfoGcsService, RegisterWorker, + worker_info_grpc_client_, ) + private: void Init(const std::string &address, const int port, ClientCallManager &client_call_manager) { diff --git a/src/ray/rpc/gcs_server/gcs_rpc_server.h b/src/ray/rpc/gcs_server/gcs_rpc_server.h index 8e1c84ecb..d12340814 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_server.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_server.h @@ -390,6 +390,10 @@ class WorkerInfoGcsServiceHandler { virtual void HandleReportWorkerFailure(const ReportWorkerFailureRequest &request, ReportWorkerFailureReply *reply, SendReplyCallback send_reply_callback) = 0; + + virtual void HandleRegisterWorker(const RegisterWorkerRequest &request, + RegisterWorkerReply *reply, + SendReplyCallback send_reply_callback) = 0; }; /// The `GrpcService` for `WorkerInfoGcsService`. 
@@ -409,6 +413,7 @@ class WorkerInfoGrpcService : public GrpcService { const std::unique_ptr &cq, std::vector> *server_call_factories) override { WORKER_INFO_SERVICE_RPC_HANDLER(ReportWorkerFailure); + WORKER_INFO_SERVICE_RPC_HANDLER(RegisterWorker); } private: diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java index 2307b64e4..5f17eac4f 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java @@ -1,6 +1,5 @@ package org.ray.streaming.runtime.transfer; -import com.google.common.base.Preconditions; import org.ray.runtime.RayNativeRuntime; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; @@ -24,16 +23,12 @@ public class TransferHandler { private long writerClientNative; private long readerClientNative; - public TransferHandler(long coreWorkerNative, - JavaFunctionDescriptor writerAsyncFunc, + public TransferHandler(JavaFunctionDescriptor writerAsyncFunc, JavaFunctionDescriptor writerSyncFunc, JavaFunctionDescriptor readerAsyncFunc, JavaFunctionDescriptor readerSyncFunc) { - Preconditions.checkArgument(coreWorkerNative != 0); - writerClientNative = createWriterClientNative( - coreWorkerNative, writerAsyncFunc, writerSyncFunc); - readerClientNative = createReaderClientNative( - coreWorkerNative, readerAsyncFunc, readerSyncFunc); + writerClientNative = createWriterClientNative(writerAsyncFunc, writerSyncFunc); + readerClientNative = createReaderClientNative(readerAsyncFunc, readerSyncFunc); } public void onWriterMessage(byte[] buffer) { @@ -53,12 +48,10 @@ public class TransferHandler { } private native long createWriterClientNative( - long coreWorkerNative, 
FunctionDescriptor asyncFunc, FunctionDescriptor syncFunc); private native long createReaderClientNative( - long coreWorkerNative, FunctionDescriptor asyncFunc, FunctionDescriptor syncFunc); diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java index 73c8f8a59..aff7c047c 100644 --- a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java @@ -2,9 +2,6 @@ package org.ray.streaming.runtime.worker; import java.io.Serializable; import java.util.Map; - -import org.ray.api.Ray; -import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.streaming.runtime.core.graph.ExecutionGraph; import org.ray.streaming.runtime.core.graph.ExecutionNode; @@ -62,7 +59,6 @@ public class JobWorker implements Serializable { Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); if (channelType.equals(Config.NATIVE_CHANNEL)) { transferHandler = new TransferHandler( - getNativeCoreWorker(), new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"), new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"), new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"), @@ -148,13 +144,4 @@ public class JobWorker implements Serializable { public byte[] onWriterMessageSync(byte[] buffer) { return transferHandler.onWriterMessageSync(buffer); } - - private static long getNativeCoreWorker() { - long pointer = 0; - if (Ray.internal() instanceof RayMultiWorkerNativeRuntime) { - pointer = ((RayMultiWorkerNativeRuntime) Ray.internal()) - .getCurrentRuntime().getNativeCoreWorkerPointer(); - } - return pointer; - } } diff --git 
a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java index a852a6fd5..71364bbee 100644 --- a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java @@ -10,7 +10,6 @@ import java.util.Random; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.id.ActorId; -import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.actor.NativeRayActor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.streaming.runtime.transfer.ChannelID; @@ -29,8 +28,7 @@ public class Worker { protected TransferHandler transferHandler = null; public Worker() { - transferHandler = new TransferHandler(((RayMultiWorkerNativeRuntime) Ray.internal()) - .getCurrentRuntime().getNativeCoreWorkerPointer(), + transferHandler = new TransferHandler( new JavaFunctionDescriptor(Worker.class.getName(), "onWriterMessage", "([B)V"), new JavaFunctionDescriptor(Worker.class.getName(), diff --git a/streaming/python/includes/libstreaming.pxd b/streaming/python/includes/libstreaming.pxd index 0b1ad27c5..bf210b4a5 100644 --- a/streaming/python/includes/libstreaming.pxd +++ b/streaming/python/includes/libstreaming.pxd @@ -31,7 +31,6 @@ from ray.includes.unique_ids cimport ( CTaskID, CObjectID, ) -from ray.includes.libcoreworker cimport CCoreWorker cdef extern from "status.h" namespace "ray::streaming" nogil: cdef cppclass CStreamingStatus "ray::streaming::StreamingStatus": @@ -100,15 +99,13 @@ cdef extern from "message/message_bundle.h" namespace "ray::streaming" nogil: cdef extern from "queue/queue_client.h" namespace "ray::streaming" nogil: cdef cppclass CReaderClient "ray::streaming::ReaderClient": - CReaderClient(CCoreWorker *core_worker, - CRayFunction 
&async_func, + CReaderClient(CRayFunction &async_func, CRayFunction &sync_func) void OnReaderMessage(shared_ptr[CLocalMemoryBuffer] buffer); shared_ptr[CLocalMemoryBuffer] OnReaderMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); cdef cppclass CWriterClient "ray::streaming::WriterClient": - CWriterClient(CCoreWorker *core_worker, - CRayFunction &async_func, + CWriterClient(CRayFunction &async_func, CRayFunction &sync_func) void OnWriterMessage(shared_ptr[CLocalMemoryBuffer] buffer); shared_ptr[CLocalMemoryBuffer] OnWriterMessageSync(shared_ptr[CLocalMemoryBuffer] buffer); diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi index b57f30f10..c317137d6 100644 --- a/streaming/python/includes/transfer.pxi +++ b/streaming/python/includes/transfer.pxi @@ -19,14 +19,11 @@ from ray.includes.unique_ids cimport ( ) from ray._raylet cimport ( Buffer, - CoreWorker, ActorID, ObjectID, FunctionDescriptor, ) -from ray.includes.libcoreworker cimport CCoreWorker - cimport ray.streaming.includes.libstreaming as libstreaming from ray.streaming.includes.libstreaming cimport ( CStreamingStatus, @@ -52,16 +49,14 @@ cdef class ReaderClient: CReaderClient *client def __cinit__(self, - CoreWorker worker, FunctionDescriptor async_func, FunctionDescriptor sync_func): cdef: - CCoreWorker *core_worker = worker.core_worker.get() CRayFunction async_native_func CRayFunction sync_native_func async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor) sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor) - self.client = new CReaderClient(core_worker, async_native_func, sync_native_func) + self.client = new CReaderClient(async_native_func, sync_native_func) def __dealloc__(self): del self.client @@ -91,16 +86,14 @@ cdef class WriterClient: CWriterClient * client def __cinit__(self, - CoreWorker worker, FunctionDescriptor async_func, FunctionDescriptor sync_func): cdef: - CCoreWorker *core_worker = worker.core_worker.get() 
CRayFunction async_native_func CRayFunction sync_native_func async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor) sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor) - self.client = new CWriterClient(core_worker, async_native_func, sync_native_func) + self.client = new CWriterClient(async_native_func, sync_native_func) def __dealloc__(self): del self.client diff --git a/streaming/python/runtime/worker.py b/streaming/python/runtime/worker.py index a2e891876..3af4fcd78 100644 --- a/streaming/python/runtime/worker.py +++ b/streaming/python/runtime/worker.py @@ -48,7 +48,6 @@ class JobWorker(object): self.task_id, self.stream_processor)) if self.config.get(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL): - core_worker = ray.worker.global_worker.core_worker reader_async_func = PythonFunctionDescriptor( __name__, self.on_reader_message.__name__, self.__class__.__name__) @@ -56,7 +55,7 @@ class JobWorker(object): __name__, self.on_reader_message_sync.__name__, self.__class__.__name__) self.reader_client = _streaming.ReaderClient( - core_worker, reader_async_func, reader_sync_func) + reader_async_func, reader_sync_func) writer_async_func = PythonFunctionDescriptor( __name__, self.on_writer_message.__name__, self.__class__.__name__) @@ -64,7 +63,7 @@ class JobWorker(object): __name__, self.on_writer_message_sync.__name__, self.__class__.__name__) self.writer_client = _streaming.WriterClient( - core_worker, writer_async_func, writer_sync_func) + writer_async_func, writer_sync_func) self.task = self.create_stream_task() self.task.start() diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py index dffdb0554..7bc389e93 100644 --- a/streaming/python/tests/test_direct_transfer.py +++ b/streaming/python/tests/test_direct_transfer.py @@ -12,21 +12,20 @@ from ray.streaming.config import Config @ray.remote class Worker: def __init__(self): - core_worker = ray.worker.global_worker.core_worker 
writer_async_func = PythonFunctionDescriptor( __name__, self.on_writer_message.__name__, self.__class__.__name__) writer_sync_func = PythonFunctionDescriptor( __name__, self.on_writer_message_sync.__name__, self.__class__.__name__) - self.writer_client = _streaming.WriterClient( - core_worker, writer_async_func, writer_sync_func) + self.writer_client = _streaming.WriterClient(writer_async_func, + writer_sync_func) reader_async_func = PythonFunctionDescriptor( __name__, self.on_reader_message.__name__, self.__class__.__name__) reader_sync_func = PythonFunctionDescriptor( __name__, self.on_reader_message_sync.__name__, self.__class__.__name__) - self.reader_client = _streaming.ReaderClient( - core_worker, reader_async_func, reader_sync_func) + self.reader_client = _streaming.ReaderClient(reader_async_func, + reader_sync_func) self.writer = None self.output_channel_id = None self.reader = None diff --git a/streaming/src/event_service.cc b/streaming/src/event_service.cc index 70bea209a..04377871c 100644 --- a/streaming/src/event_service.cc +++ b/streaming/src/event_service.cc @@ -1,6 +1,7 @@ #include #include "event_service.h" + namespace ray { namespace streaming { @@ -105,7 +106,11 @@ Event &EventQueue::Front() { } EventService::EventService(uint32_t event_size) - : event_queue_(std::make_shared(event_size)), stop_flag_(false) {} + : worker_id_(CoreWorkerProcess::IsInitialized() + ? CoreWorkerProcess::GetCoreWorker().GetWorkerID() + : WorkerID::Nil()), + event_queue_(std::make_shared(event_size)), + stop_flag_(false) {} EventService::~EventService() { stop_flag_ = true; // No need to join if loop thread has never been created. 
@@ -154,6 +159,9 @@ void EventService::Execute(Event &event) { } void EventService::LoopThreadHandler() { + if (CoreWorkerProcess::IsInitialized()) { + CoreWorkerProcess::SetCurrentThreadWorkerId(worker_id_); + } while (true) { if (stop_flag_) { break; diff --git a/streaming/src/event_service.h b/streaming/src/event_service.h index a4a89fe8b..c5f2b117e 100644 --- a/streaming/src/event_service.h +++ b/streaming/src/event_service.h @@ -7,6 +7,7 @@ #include #include "channel.h" +#include "ray/core_worker/core_worker.h" #include "ring_buffer.h" #include "util/streaming_util.h" @@ -127,6 +128,7 @@ class EventService { void LoopThreadHandler(); private: + WorkerID worker_id_; std::unordered_map event_handle_map_; std::shared_ptr event_queue_; std::shared_ptr loop_thread_; diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc index 43f5e4a08..997404fdd 100644 --- a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc @@ -14,25 +14,19 @@ static std::shared_ptr JByteArrayToBuffer(JNIEnv *env, JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( - JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func, - jobject sync_func) { + JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) { auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func); auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); - auto *writer_client = - new WriterClient(reinterpret_cast(core_worker_ptr), - ray_async_func, ray_sync_func); + auto *writer_client = new WriterClient(ray_async_func, ray_sync_func); return reinterpret_cast(writer_client); } JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative( - JNIEnv *env, jobject 
this_obj, jlong core_worker_ptr, jobject async_func, - jobject sync_func) { + JNIEnv *env, jobject this_obj, jobject async_func, jobject sync_func) { ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func); ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); - auto *reader_client = - new ReaderClient(reinterpret_cast(core_worker_ptr), - ray_async_func, ray_sync_func); + auto *reader_client = new ReaderClient(ray_async_func, ray_sync_func); return reinterpret_cast(reader_client); } diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h index 1cdc3e8ab..61a38f189 100644 --- a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h @@ -12,48 +12,58 @@ extern "C" { * Method: createWriterClientNative * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative - (JNIEnv *, jobject, jlong, jobject, jobject); +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(JNIEnv *, + jobject, + jobject, + jobject); /* * Class: org_ray_streaming_runtime_transfer_TransferHandler * Method: createReaderClientNative * Signature: (J)J */ -JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative - (JNIEnv *, jobject, jlong, jobject, jobject); +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(JNIEnv *, + jobject, + jobject, + jobject); /* * Class: org_ray_streaming_runtime_transfer_TransferHandler * Method: handleWriterMessageNative * Signature: (J[B)V */ -JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative - (JNIEnv *, jobject, jlong, jbyteArray); 
+JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( + JNIEnv *, jobject, jlong, jbyteArray); /* * Class: org_ray_streaming_runtime_transfer_TransferHandler * Method: handleWriterMessageSyncNative * Signature: (J[B)[B */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative - (JNIEnv *, jobject, jlong, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( + JNIEnv *, jobject, jlong, jbyteArray); /* * Class: org_ray_streaming_runtime_transfer_TransferHandler * Method: handleReaderMessageNative * Signature: (J[B)V */ -JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative - (JNIEnv *, jobject, jlong, jbyteArray); +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative( + JNIEnv *, jobject, jlong, jbyteArray); /* * Class: org_ray_streaming_runtime_transfer_TransferHandler * Method: handleReaderMessageSyncNative * Signature: (J[B)[B */ -JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative - (JNIEnv *, jobject, jlong, jbyteArray); +JNIEXPORT jbyteArray JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( + JNIEnv *, jobject, jlong, jbyteArray); #ifdef __cplusplus } diff --git a/streaming/src/queue/queue_client.h b/streaming/src/queue/queue_client.h index a7d5171ca..5d191b3fc 100644 --- a/streaming/src/queue/queue_client.h +++ b/streaming/src/queue/queue_client.h @@ -14,16 +14,14 @@ namespace streaming { class ReaderClient { public: /// Construct a ReaderClient object. 
- /// \param[in] core_worker CoreWorker C++ pointer of current actor /// \param[in] async_func DataReader's raycall function descriptor to be called by /// DataWriter, asynchronous semantics \param[in] sync_func DataReader's raycall /// function descriptor to be called by DataWriter, synchronous semantics - ReaderClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func) - : core_worker_(core_worker) { + ReaderClient(RayFunction &async_func, RayFunction &sync_func) { DownstreamQueueMessageHandler::peer_async_function_ = async_func; DownstreamQueueMessageHandler::peer_sync_function_ = sync_func; downstream_handler_ = ray::streaming::DownstreamQueueMessageHandler::CreateService( - core_worker_, core_worker_->GetWorkerContext().GetCurrentActorID()); + CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); } /// Post buffer to downstream queue service, asynchronously. @@ -34,19 +32,17 @@ class ReaderClient { std::shared_ptr buffer); private: - CoreWorker *core_worker_; std::shared_ptr downstream_handler_; }; /// Interface of streaming queue for DataWriter. Similar to ReaderClient. 
class WriterClient { public: - WriterClient(CoreWorker *core_worker, RayFunction &async_func, RayFunction &sync_func) - : core_worker_(core_worker) { + WriterClient(RayFunction &async_func, RayFunction &sync_func) { UpstreamQueueMessageHandler::peer_async_function_ = async_func; UpstreamQueueMessageHandler::peer_sync_function_ = sync_func; upstream_handler_ = ray::streaming::UpstreamQueueMessageHandler::CreateService( - core_worker, core_worker_->GetWorkerContext().GetCurrentActorID()); + CoreWorkerProcess::GetCoreWorker().GetWorkerContext().GetCurrentActorID()); } void OnWriterMessage(std::shared_ptr buffer); @@ -54,7 +50,6 @@ class WriterClient { std::shared_ptr buffer); private: - CoreWorker *core_worker_; std::shared_ptr upstream_handler_; }; } // namespace streaming diff --git a/streaming/src/queue/queue_handler.cc b/streaming/src/queue/queue_handler.cc index d3e6cdbb7..6ac0cf0ed 100644 --- a/streaming/src/queue/queue_handler.cc +++ b/streaming/src/queue/queue_handler.cc @@ -85,8 +85,8 @@ std::shared_ptr QueueMessageHandler::GetOutTransport( void QueueMessageHandler::SetPeerActorID(const ObjectID &queue_id, const ActorID &actor_id) { actors_.emplace(queue_id, actor_id); - out_transports_.emplace( - queue_id, std::make_shared(core_worker_, actor_id)); + out_transports_.emplace(queue_id, + std::make_shared(actor_id)); } ActorID QueueMessageHandler::GetPeerActorID(const ObjectID &queue_id) { @@ -113,10 +113,9 @@ void QueueMessageHandler::Stop() { } std::shared_ptr UpstreamQueueMessageHandler::CreateService( - CoreWorker *core_worker, const ActorID &actor_id) { + const ActorID &actor_id) { if (nullptr == upstream_handler_) { - upstream_handler_ = - std::make_shared(core_worker, actor_id); + upstream_handler_ = std::make_shared(actor_id); } return upstream_handler_; } @@ -247,11 +246,9 @@ void UpstreamQueueMessageHandler::ReleaseAllUpQueues() { } std::shared_ptr -DownstreamQueueMessageHandler::CreateService(CoreWorker *core_worker, - const ActorID &actor_id) { 
+DownstreamQueueMessageHandler::CreateService(const ActorID &actor_id) { if (nullptr == downstream_handler_) { - downstream_handler_ = - std::make_shared(core_worker, actor_id); + downstream_handler_ = std::make_shared(actor_id); } return downstream_handler_; } diff --git a/streaming/src/queue/queue_handler.h b/streaming/src/queue/queue_handler.h index 0563b564b..0525ec8a6 100644 --- a/streaming/src/queue/queue_handler.h +++ b/streaming/src/queue/queue_handler.h @@ -24,16 +24,9 @@ namespace streaming { class QueueMessageHandler { public: /// Construct a QueueMessageHandler instance. - /// \param[in] core_worker CoreWorker C++ pointer of current actor, used to call Core - /// Worker's api. - /// For Python worker, the pointer can be obtained from - /// ray.worker.global_worker.core_worker; For Java worker, obtained from - /// RayNativeRuntime object through java reflection. /// \param[in] actor_id actor id of current actor. - QueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) - : core_worker_(core_worker), - actor_id_(actor_id), - queue_dummy_work_(queue_service_) { + QueueMessageHandler(const ActorID &actor_id) + : actor_id_(actor_id), queue_dummy_work_(queue_service_) { Start(); } @@ -87,8 +80,6 @@ class QueueMessageHandler { void QueueThreadCallback() { queue_service_.run(); } protected: - /// CoreWorker C++ pointer of current actor - CoreWorker *core_worker_; /// actor_id actor id of current actor ActorID actor_id_; /// Helper function, parse message buffer to Message object. @@ -111,8 +102,7 @@ class QueueMessageHandler { class UpstreamQueueMessageHandler : public QueueMessageHandler { public: /// Construct a UpstreamQueueMessageHandler instance. - UpstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) - : QueueMessageHandler(core_worker, actor_id) {} + UpstreamQueueMessageHandler(const ActorID &actor_id) : QueueMessageHandler(actor_id) {} /// Create a upstream queue. 
/// \param[in] queue_id queue id of the queue to be created. /// \param[in] peer_actor_id actor id of peer actor. @@ -140,7 +130,7 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler { std::function)> callback) override; static std::shared_ptr CreateService( - CoreWorker *core_worker, const ActorID &actor_id); + const ActorID &actor_id); static std::shared_ptr GetService(); static RayFunction peer_sync_function_; @@ -157,8 +147,8 @@ class UpstreamQueueMessageHandler : public QueueMessageHandler { /// UpstreamQueueMessageHandler holds and manages all downstream queues of current actor. class DownstreamQueueMessageHandler : public QueueMessageHandler { public: - DownstreamQueueMessageHandler(CoreWorker *core_worker, const ActorID &actor_id) - : QueueMessageHandler(core_worker, actor_id) {} + DownstreamQueueMessageHandler(const ActorID &actor_id) + : QueueMessageHandler(actor_id) {} std::shared_ptr CreateDownstreamQueue(const ObjectID &queue_id, const ActorID &peer_actor_id); bool DownstreamQueueExists(const ObjectID &queue_id); @@ -178,7 +168,7 @@ class DownstreamQueueMessageHandler : public QueueMessageHandler { std::function)> callback); static std::shared_ptr CreateService( - CoreWorker *core_worker, const ActorID &actor_id); + const ActorID &actor_id); static std::shared_ptr GetService(); static RayFunction peer_sync_function_; static RayFunction peer_async_function_; diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index 95f5c11d4..6341fce0a 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -28,10 +28,9 @@ void Transport::SendInternal(std::shared_ptr buffer, args.emplace_back(TaskArg::PassByValue(std::make_shared( std::move(buffer), meta, std::vector(), true))); - STREAMING_CHECK(core_worker_ != nullptr); std::vector> results; - ray::Status st = - core_worker_->SubmitActorTask(peer_actor_id_, function, args, options, &return_ids); + ray::Status st = 
CoreWorkerProcess::GetCoreWorker().SubmitActorTask( + peer_actor_id_, function, args, options, &return_ids); if (!st.ok()) { STREAMING_LOG(ERROR) << "SubmitActorTask failed. " << st; } @@ -50,7 +49,8 @@ std::shared_ptr Transport::SendForResult( SendInternal(buffer, function, TASK_OPTION_RETURN_NUM_1, return_ids); std::vector> results; - Status get_st = core_worker_->Get(return_ids, timeout_ms, &results); + Status get_st = + CoreWorkerProcess::GetCoreWorker().Get(return_ids, timeout_ms, &results); if (!get_st.ok()) { STREAMING_LOG(ERROR) << "Get fail."; return nullptr; diff --git a/streaming/src/queue/transport.h b/streaming/src/queue/transport.h index 3f26754a4..8d702f41b 100644 --- a/streaming/src/queue/transport.h +++ b/streaming/src/queue/transport.h @@ -13,11 +13,10 @@ namespace streaming { class Transport { public: /// Construct a Transport object. - /// \param[in] core_worker CoreWorker C++ pointer of current actor, which we call direct - /// actor call interface with. /// \param[in] peer_actor_id actor id of peer actor. - Transport(CoreWorker *core_worker, const ActorID &peer_actor_id) - : core_worker_(core_worker), peer_actor_id_(peer_actor_id) {} + Transport(const ActorID &peer_actor_id) + : worker_id_(CoreWorkerProcess::GetCoreWorker().GetWorkerID()), + peer_actor_id_(peer_actor_id) {} virtual ~Transport() = default; /// Send buffer asynchronously, peer's `function` will be called. 
@@ -55,7 +54,7 @@ class Transport { std::vector &return_ids); private: - CoreWorker *core_worker_; + WorkerID worker_id_; ActorID peer_actor_id_; }; } // namespace streaming diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 5762d81ff..21212ac28 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -20,11 +20,9 @@ namespace streaming { class StreamingQueueTestSuite { public: - StreamingQueueTestSuite(std::shared_ptr core_worker, ActorID &peer_actor_id, - std::vector queue_ids, + StreamingQueueTestSuite(ActorID &peer_actor_id, std::vector queue_ids, std::vector rescale_queue_ids) - : core_worker_(core_worker), - peer_actor_id_(peer_actor_id), + : peer_actor_id_(peer_actor_id), queue_ids_(queue_ids), rescale_queue_ids_(rescale_queue_ids) {} @@ -52,7 +50,6 @@ class StreamingQueueTestSuite { std::string current_test_; bool status_; std::shared_ptr executor_thread_; - std::shared_ptr core_worker_; ActorID peer_actor_id_; std::vector queue_ids_; std::vector rescale_queue_ids_; @@ -60,11 +57,9 @@ class StreamingQueueTestSuite { class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { public: - StreamingQueueWriterTestSuite(std::shared_ptr core_worker, - ActorID &peer_actor_id, std::vector queue_ids, + StreamingQueueWriterTestSuite(ActorID &peer_actor_id, std::vector queue_ids, std::vector rescale_queue_ids) - : StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids, - rescale_queue_ids) { + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { test_func_map_ = { {"streaming_writer_exactly_once_test", std::bind(&StreamingQueueWriterTestSuite::StreamingWriterExactlyOnceTest, @@ -135,11 +130,9 @@ class StreamingQueueWriterTestSuite : public StreamingQueueTestSuite { class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { public: - StreamingQueueReaderTestSuite(std::shared_ptr core_worker, - ActorID peer_actor_id, std::vector queue_ids, + 
StreamingQueueReaderTestSuite(ActorID peer_actor_id, std::vector queue_ids, std::vector rescale_queue_ids) - : StreamingQueueTestSuite(core_worker, peer_actor_id, queue_ids, - rescale_queue_ids) { + : StreamingQueueTestSuite(peer_actor_id, queue_ids, rescale_queue_ids) { test_func_map_ = { {"streaming_writer_exactly_once_test", std::bind(&StreamingQueueReaderTestSuite::StreamingWriterExactlyOnceTest, @@ -247,7 +240,7 @@ class StreamingQueueReaderTestSuite : public StreamingQueueTestSuite { class TestSuiteFactory { public: static std::shared_ptr CreateTestSuite( - std::shared_ptr worker, std::shared_ptr message) { + std::shared_ptr message) { std::shared_ptr test_suite = nullptr; std::string suite_name = message->TestSuiteName(); queue::protobuf::StreamingQueueTestRole role = message->Role(); @@ -258,14 +251,14 @@ class TestSuiteFactory { if (role == queue::protobuf::StreamingQueueTestRole::WRITER) { if (suite_name == "StreamingWriterTest") { test_suite = std::make_shared( - worker, peer_actor_id, queue_ids, rescale_queue_ids); + peer_actor_id, queue_ids, rescale_queue_ids); } else { STREAMING_CHECK(false) << "unsurported suite_name: " << suite_name; } } else { if (suite_name == "StreamingWriterTest") { test_suite = std::make_shared( - worker, peer_actor_id, queue_ids, rescale_queue_ids); + peer_actor_id, queue_ids, rescale_queue_ids); } else { STREAMING_CHECK(false) << "unsupported suite_name: " << suite_name; } @@ -280,10 +273,30 @@ class StreamingWorker { StreamingWorker(const std::string &store_socket, const std::string &raylet_socket, int node_manager_port, const gcs::GcsClientOptions &gcs_options) : test_suite_(nullptr), peer_actor_handle_(nullptr) { - worker_ = std::make_shared( - WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, - JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port, - std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)); + CoreWorkerOptions options = { + WorkerType::WORKER, // 
worker_type + Language::PYTHON, // language + store_socket, // store_socket + raylet_socket, // raylet_socket + JobID::FromInt(1), // job_id + gcs_options, // gcs_options + "", // log_dir + true, // install_failure_signal_handler + "127.0.0.1", // node_ip_address + node_manager_port, // node_manager_port + "", // driver_name + "", // stdout_file + "", // stderr_file + std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, + _7), // task_execution_callback + nullptr, // check_signals + nullptr, // gc_collect + nullptr, // get_lang_stack + true, // ref_counting_enabled + false, // is_local_mode + 1, // num_workers + }; + CoreWorkerProcess::Initialize(options); RayFunction reader_async_call_func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( @@ -298,16 +311,16 @@ class StreamingWorker { ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")}; - reader_client_ = std::make_shared(worker_.get(), reader_async_call_func, - reader_sync_call_func); - writer_client_ = std::make_shared(worker_.get(), writer_async_call_func, - writer_sync_call_func); + reader_client_ = + std::make_shared(reader_async_call_func, reader_sync_call_func); + writer_client_ = + std::make_shared(writer_async_call_func, writer_sync_call_func); STREAMING_LOG(INFO) << "StreamingWorker constructor"; } - void StartExecutingTasks() { + void RunTaskExecutionLoop() { // Start executing tasks. 
- worker_->StartExecutingTasks(); + CoreWorkerProcess::RunTaskExecutionLoop(); } private: @@ -403,7 +416,8 @@ class StreamingWorker { STREAMING_LOG(INFO) << "Init message: " << message->ToString(); std::string actor_handle_serialized = message->ActorHandleSerialized(); - worker_->DeserializeAndRegisterActorHandle(actor_handle_serialized, ObjectID::Nil()); + CoreWorkerProcess::GetCoreWorker().DeserializeAndRegisterActorHandle( + actor_handle_serialized, ObjectID::Nil()); std::shared_ptr actor_handle(new ActorHandle(actor_handle_serialized)); STREAMING_CHECK(actor_handle != nullptr); STREAMING_LOG(INFO) << " actor id from handle: " << actor_handle->GetActorID(); @@ -421,12 +435,11 @@ class StreamingWorker { STREAMING_LOG(INFO) << "rescale queue: " << qid; } - test_suite_ = TestSuiteFactory::CreateTestSuite(worker_, message); + test_suite_ = TestSuiteFactory::CreateTestSuite(message); STREAMING_CHECK(test_suite_ != nullptr); } private: - std::shared_ptr worker_; std::shared_ptr reader_client_; std::shared_ptr writer_client_; std::shared_ptr test_thread_; @@ -446,6 +459,6 @@ int main(int argc, char **argv) { ray::gcs::GcsClientOptions gcs_options("127.0.0.1", 6379, ""); ray::streaming::StreamingWorker worker(store_socket, raylet_socket, node_manager_port, gcs_options); - worker.StartExecutingTasks(); + worker.RunTaskExecutionLoop(); return 0; } diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index 34e5455dd..fb4fa845f 100644 --- a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -153,11 +153,12 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { ASSERT_TRUE(system(("rm -rf " + raylet_socket_name + ".pid").c_str()) == 0); } - void InitWorker(CoreWorker &driver, ActorID &self_actor_id, ActorID &peer_actor_id, + void InitWorker(ActorID &self_actor_id, ActorID &peer_actor_id, const queue::protobuf::StreamingQueueTestRole role, const std::vector &queue_ids, const std::vector 
&rescale_queue_ids, std::string suite_name, std::string test_name, uint64_t param) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); std::string forked_serialized_str; ObjectID actor_handle_id; Status st = driver.SerializeActorHandle(peer_actor_id, &forked_serialized_str, @@ -179,7 +180,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids)); } - void SubmitTestToActor(CoreWorker &driver, ActorID &actor_id, const std::string test) { + void SubmitTestToActor(ActorID &actor_id, const std::string test) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); std::vector args; @@ -194,7 +196,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); } - bool CheckCurTest(CoreWorker &driver, ActorID &actor_id, const std::string test_name) { + bool CheckCurTest(ActorID &actor_id, const std::string test_name) { + auto &driver = CoreWorkerProcess::GetCoreWorker(); uint8_t data[8]; auto buffer = std::make_shared(data, 8, true); std::vector args; @@ -255,8 +258,7 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { return message->Status(); } - ActorID CreateActorHelper(CoreWorker &worker, - const std::unordered_map &resources, + ActorID CreateActorHelper(const std::unordered_map &resources, bool is_direct_call, uint64_t max_reconstructions) { std::unique_ptr actor_handle; @@ -277,8 +279,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { // Create an actor. 
ActorID actor_id; - RAY_CHECK_OK( - worker.CreateActor(func, args, actor_options, /*extension_data*/ "", &actor_id)); + RAY_CHECK_OK(CoreWorkerProcess::GetCoreWorker().CreateActor( + func, args, actor_options, /*extension_data*/ "", &actor_id)); return actor_id; } @@ -305,33 +307,54 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { } STREAMING_LOG(INFO) << "Sub process: writer."; - CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], NextJobId(), gcs_options_, "", "127.0.0.1", - node_manager_port_, nullptr); + CoreWorkerOptions options = { + WorkerType::DRIVER, // worker_type + Language::PYTHON, // language + raylet_store_socket_names_[0], // store_socket + raylet_socket_names_[0], // raylet_socket + NextJobId(), // job_id + gcs_options_, // gcs_options + "", // log_dir + true, // install_failure_signal_handler + "127.0.0.1", // node_ip_address + node_manager_port_, // node_manager_port + "queue_tests", // driver_name + "", // stdout_file + "", // stderr_file + nullptr, // task_execution_callback + nullptr, // check_signals + nullptr, // gc_collect + nullptr, // get_lang_stack + true, // ref_counting_enabled + false, // is_local_mode + 1, // num_workers + }; + InitShutdownRAII core_worker_raii(CoreWorkerProcess::Initialize, + CoreWorkerProcess::Shutdown, options); // Create writer and reader actors std::unordered_map resources; - auto actor_id_writer = CreateActorHelper(driver, resources, true, 0); - auto actor_id_reader = CreateActorHelper(driver, resources, true, 0); + auto actor_id_writer = CreateActorHelper(resources, true, 0); + auto actor_id_reader = CreateActorHelper(resources, true, 0); - InitWorker(driver, actor_id_writer, actor_id_reader, + InitWorker(actor_id_writer, actor_id_reader, queue::protobuf::StreamingQueueTestRole::WRITER, queue_id_vec, rescale_queue_id_vec, suite_name, test_name, GetParam()); - InitWorker(driver, actor_id_reader, actor_id_writer, + 
InitWorker(actor_id_reader, actor_id_writer, queue::protobuf::StreamingQueueTestRole::READER, queue_id_vec, rescale_queue_id_vec, suite_name, test_name, GetParam()); std::this_thread::sleep_for(std::chrono::milliseconds(2000)); - SubmitTestToActor(driver, actor_id_writer, test_name); - SubmitTestToActor(driver, actor_id_reader, test_name); + SubmitTestToActor(actor_id_writer, test_name); + SubmitTestToActor(actor_id_reader, test_name); uint64_t slept_time_ms = 0; while (slept_time_ms < timeout_ms) { std::this_thread::sleep_for(std::chrono::milliseconds(5 * 1000)); STREAMING_LOG(INFO) << "Check test status."; - if (CheckCurTest(driver, actor_id_writer, test_name) && - CheckCurTest(driver, actor_id_reader, test_name)) { + if (CheckCurTest(actor_id_writer, test_name) && + CheckCurTest(actor_id_reader, test_name)) { STREAMING_LOG(INFO) << "Test Success, Exit."; return; }