diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index cdad95e16..4cf4ffac6 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,6 +1,7 @@ package org.ray.api; import java.util.List; +import java.util.concurrent.Callable; import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; @@ -117,6 +118,28 @@ public final class Ray extends RayCall { return runtime.wait(waitList, waitList.size(), Integer.MAX_VALUE); } + /** + * If users want to use Ray API in their own threads, they should wrap their {@link Runnable} + * objects with this method. + * + * @param runnable The runnable to wrap. + * @return The wrapped runnable. + */ + public static Runnable wrapRunnable(Runnable runnable) { + return runtime.wrapRunnable(runnable); + } + + /** + * If users want to use Ray API in their own threads, they should wrap their {@link Callable} + * objects with this method. + * + * @param callable The callable to wrap. + * @return The wrapped callable. + */ + public static Callable wrapCallable(Callable callable) { + return runtime.wrapCallable(callable); + } + /** * Get the underlying runtime instance. 
*/ diff --git a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java index 2e14ca858..928df3221 100644 --- a/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/org/ray/api/options/ActorCreationOptions.java @@ -30,7 +30,7 @@ public class ActorCreationOptions extends BaseTaskOptions { private Map resources = new HashMap<>(); private int maxReconstructions = NO_RECONSTRUCTION; - private String jvmOptions = ""; + private String jvmOptions = null; public Builder setResources(Map resources) { this.resources = resources; diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 5a29c9a39..46d5ca842 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -1,6 +1,7 @@ package org.ray.api.runtime; import java.util.List; +import java.util.concurrent.Callable; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.RayPyActor; @@ -141,4 +142,18 @@ public interface RayRuntime { */ RayPyActor createPyActor(String moduleName, String className, Object[] args, ActorCreationOptions options); + + /** + * Wrap a {@link Runnable} with necessary context capture. + * @param runnable The runnable to wrap. + * @return The wrapped runnable. + */ + Runnable wrapRunnable(Runnable runnable); + + /** + * Wrap a {@link Callable} with necessary context capture. + * @param callable The callable to wrap. + * @return The wrapped callable. 
+ */ + Callable wrapCallable(Callable callable); } diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index e3132aa29..b3ea54e3f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -4,6 +4,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.concurrent.Callable; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.RayPyActor; @@ -59,11 +60,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { runtimeContext = new RuntimeContextImpl(this); } - /** - * Start runtime. - */ - public abstract void start() throws Exception; - @Override public abstract void shutdown(); @@ -168,6 +164,16 @@ public abstract class AbstractRayRuntime implements RayRuntime { return (RayPyActor) createActorImpl(functionDescriptor, args, options); } + @Override + public Runnable wrapRunnable(Runnable runnable) { + return runnable; + } + + @Override + public Callable wrapCallable(Callable callable) { + return callable; + } + private RayObject callNormalFunction(FunctionDescriptor functionDescriptor, Object[] args, int numReturns, CallOptions options) { List functionArgs = ArgumentsBuilder diff --git a/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java b/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java index 7223fa28e..f95d7cfa4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java +++ b/java/runtime/src/main/java/org/ray/runtime/DefaultRayRuntimeFactory.java @@ -4,11 +4,12 @@ import org.ray.api.runtime.RayRuntime; import org.ray.api.runtime.RayRuntimeFactory; import org.ray.runtime.config.RayConfig; import org.ray.runtime.config.RunMode; +import 
org.ray.runtime.generated.Common.WorkerType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * The default Ray runtime factory. It produces an instance of AbstractRayRuntime. + * The default Ray runtime factory. It produces an instance of RayRuntime. */ public class DefaultRayRuntimeFactory implements RayRuntimeFactory { @@ -18,14 +19,16 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory { public RayRuntime createRayRuntime() { RayConfig rayConfig = RayConfig.create(); try { - AbstractRayRuntime runtime; + RayRuntime runtime; if (rayConfig.runMode == RunMode.SINGLE_PROCESS) { runtime = new RayDevRuntime(rayConfig); } else { - runtime = new RayNativeRuntime(rayConfig); + if (rayConfig.workerMode == WorkerType.DRIVER) { + runtime = new RayNativeRuntime(rayConfig); + } else { + runtime = new RayMultiWorkerNativeRuntime(rayConfig); + } } - - runtime.start(); return runtime; } catch (Exception e) { LOGGER.error("Failed to initialize ray runtime", e); diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index 7653177a1..c96b811ce 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -11,14 +11,10 @@ import org.ray.runtime.task.TaskExecutor; public class RayDevRuntime extends AbstractRayRuntime { - public RayDevRuntime(RayConfig rayConfig) { - super(rayConfig); - } - private AtomicInteger jobCounter = new AtomicInteger(0); - @Override - public void start() { + public RayDevRuntime(RayConfig rayConfig) { + super(rayConfig); if (rayConfig.getJobId().isNil()) { rayConfig.setJobId(nextJobId()); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java new file mode 100644 index 000000000..8c85048c9 --- /dev/null +++ 
b/java/runtime/src/main/java/org/ray/runtime/RayMultiWorkerNativeRuntime.java @@ -0,0 +1,192 @@ +package org.ray.runtime; + +import com.google.common.base.Preconditions; +import java.util.List; +import java.util.concurrent.Callable; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.RayPyActor; +import org.ray.api.WaitResult; +import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; +import org.ray.api.id.UniqueId; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.CallOptions; +import org.ray.api.runtime.RayRuntime; +import org.ray.api.runtimecontext.RuntimeContext; +import org.ray.runtime.config.RayConfig; +import org.ray.runtime.config.RunMode; +import org.ray.runtime.generated.Common.WorkerType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This is a proxy runtime for multi-worker support. It holds multiple {@link RayNativeRuntime} + * instances and redirects calls to the correct one based on thread context. + */ +public class RayMultiWorkerNativeRuntime implements RayRuntime { + + private static final Logger LOGGER = LoggerFactory.getLogger(RayMultiWorkerNativeRuntime.class); + + /** + * The number of workers per worker process. + */ + private final int numWorkers; + /** + * The worker threads. + */ + private final Thread[] threads; + /** + * The {@link RayNativeRuntime} instances of workers. + */ + private final RayNativeRuntime[] runtimes; + /** + * The {@link RayNativeRuntime} instance of current thread. 
+ */ + private final ThreadLocal currentThreadRuntime = new ThreadLocal<>(); + + public RayMultiWorkerNativeRuntime(RayConfig rayConfig) { + Preconditions.checkState( + rayConfig.runMode == RunMode.CLUSTER && rayConfig.workerMode == WorkerType.WORKER); + Preconditions.checkState(rayConfig.numWorkersPerProcess > 0, + "numWorkersPerProcess must be greater than 0."); + numWorkers = rayConfig.numWorkersPerProcess; + runtimes = new RayNativeRuntime[numWorkers]; + threads = new Thread[numWorkers]; + + LOGGER.info("Starting {} workers.", numWorkers); + + for (int i = 0; i < numWorkers; i++) { + final int workerIndex = i; + threads[i] = new Thread(() -> { + RayNativeRuntime runtime = new RayNativeRuntime(rayConfig); + runtimes[workerIndex] = runtime; + currentThreadRuntime.set(runtime); + runtime.run(); + }); + } + } + + public void run() { + for (int i = 0; i < numWorkers; i++) { + threads[i].start(); + } + for (int i = 0; i < numWorkers; i++) { + try { + threads[i].join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public void shutdown() { + for (int i = 0; i < numWorkers; i++) { + runtimes[i].shutdown(); + } + for (int i = 0; i < numWorkers; i++) { + try { + threads[i].join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + public RayNativeRuntime getCurrentRuntime() { + RayNativeRuntime currentRuntime = currentThreadRuntime.get(); + Preconditions.checkNotNull(currentRuntime, + "RayRuntime is not set on current thread." 
+ + " If you want to use Ray API in your own threads," + + " please wrap your `Runnable`s or `Callable`s with" + + " `Ray.wrapRunnable` or `Ray.wrapCallable`."); + return currentRuntime; + } + + @Override + public RayObject put(T obj) { + return getCurrentRuntime().put(obj); + } + + @Override + public T get(ObjectId objectId) { + return getCurrentRuntime().get(objectId); + } + + @Override + public List get(List objectIds) { + return getCurrentRuntime().get(objectIds); + } + + @Override + public WaitResult wait(List> waitList, int numReturns, int timeoutMs) { + return getCurrentRuntime().wait(waitList, numReturns, timeoutMs); + } + + @Override + public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + getCurrentRuntime().free(objectIds, localOnly, deleteCreatingTasks); + } + + @Override + public void setResource(String resourceName, double capacity, UniqueId nodeId) { + getCurrentRuntime().setResource(resourceName, capacity, nodeId); + } + + @Override + public RayObject call(RayFunc func, Object[] args, CallOptions options) { + return getCurrentRuntime().call(func, args, options); + } + + @Override + public RayObject call(RayFunc func, RayActor actor, Object[] args) { + return getCurrentRuntime().call(func, actor, args); + } + + @Override + public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, + ActorCreationOptions options) { + return getCurrentRuntime().createActor(actorFactoryFunc, args, options); + } + + @Override + public RuntimeContext getRuntimeContext() { + return getCurrentRuntime().getRuntimeContext(); + } + + @Override + public RayObject callPy(String moduleName, String functionName, Object[] args, + CallOptions options) { + return getCurrentRuntime().callPy(moduleName, functionName, args, options); + } + + @Override + public RayObject callPy(RayPyActor pyActor, String functionName, Object[] args) { + return getCurrentRuntime().callPy(pyActor, functionName, args); + } + + @Override + public RayPyActor 
createPyActor(String moduleName, String className, Object[] args, + ActorCreationOptions options) { + return getCurrentRuntime().createPyActor(moduleName, className, args, options); + } + + @Override + public Runnable wrapRunnable(Runnable runnable) { + RayNativeRuntime runtime = getCurrentRuntime(); + return () -> { + currentThreadRuntime.set(runtime); + runnable.run(); + }; + } + + @Override + public Callable wrapCallable(Callable callable) { + RayNativeRuntime runtime = getCurrentRuntime(); + return () -> { + currentThreadRuntime.set(runtime); + return callable.call(); + }; + } +} diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 28a0d0828..99908636d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -2,9 +2,12 @@ package org.ray.runtime; import com.google.common.base.Preconditions; import com.google.common.base.Strings; +import java.io.File; +import java.io.IOException; import java.lang.reflect.Field; import java.util.HashMap; import java.util.Map; +import org.apache.commons.io.FileUtils; import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.context.NativeWorkerContext; @@ -46,15 +49,20 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } LOGGER.debug("Native libraries loaded."); } - nativeSetup(RayConfig.create().logDir); + + RayConfig globalRayConfig = RayConfig.create(); + resetLibraryPath(globalRayConfig); + + try { + FileUtils.forceMkdir(new File(globalRayConfig.logDir)); + } catch (IOException e) { + throw new RuntimeException("Failed to create the log directory.", e); + } + nativeSetup(globalRayConfig.logDir); Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook)); } - public RayNativeRuntime(RayConfig rayConfig) { - super(rayConfig); - } - - protected void 
resetLibraryPath() { + private static void resetLibraryPath(RayConfig rayConfig) { if (rayConfig.libraryPath.isEmpty()) { return; } @@ -81,10 +89,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } } - @Override - public void start() { + public RayNativeRuntime(RayConfig rayConfig) { + super(rayConfig); + // Reset library path at runtime. - resetLibraryPath(); + resetLibraryPath(rayConfig); if (rayConfig.getRedisAddress() == null) { manager = new RunManager(rayConfig); diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 05e18ed27..3be4bee26 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -62,6 +62,8 @@ public class RayConfig { */ public final int numberExecThreadsForDevRuntime; + public final int numWorkersPerProcess; + private void validate() { if (workerMode == WorkerType.WORKER) { Preconditions.checkArgument(redisAddress != null, @@ -171,6 +173,8 @@ public class RayConfig { // Number of threads that execute tasks. numberExecThreadsForDevRuntime = config.getInt("ray.dev-runtime.execution-parallelism"); + numWorkersPerProcess = config.getInt("ray.raylet.config.num_workers_per_process_java"); + // Validate config. 
validate(); LOGGER.debug("Created config: {}", this); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 97f6dd135..7434e3890 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -320,10 +320,13 @@ public class RunManager { cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword); } + // Number of workers per Java worker process + cmd.add("-Dray.raylet.config.num_workers_per_process_java=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"); + cmd.addAll(rayConfig.jvmParameters); // jvm options - cmd.add("RAY_WORKER_OPTION_0"); + cmd.add("RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0"); // Main class cmd.add(WORKER_CLASS); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index 698c14973..fad4ec2aa 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -1,6 +1,8 @@ package org.ray.runtime.runner.worker; import org.ray.api.Ray; +import org.ray.api.runtime.RayRuntime; +import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.RayNativeRuntime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +22,14 @@ public class DefaultWorker { }); Ray.init(); LOGGER.info("Worker started."); - ((RayNativeRuntime)Ray.internal()).run(); + RayRuntime runtime = Ray.internal(); + if (runtime instanceof RayNativeRuntime) { + ((RayNativeRuntime)runtime).run(); + } else if (runtime instanceof RayMultiWorkerNativeRuntime) { + ((RayMultiWorkerNativeRuntime)runtime).run(); + } else { + throw new RuntimeException("Unknown RayRuntime: " + runtime); + } } catch (Exception e) { LOGGER.error("Failed to start worker.", e); } diff --git 
a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 110c178f7..c74932c8f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -5,7 +5,9 @@ import java.util.List; import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.id.ObjectId; +import org.ray.api.runtime.RayRuntime; import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectStore; import org.ray.runtime.util.Serializer; @@ -41,7 +43,11 @@ public class ArgumentsBuilder { } else { byte[] serialized = Serializer.encode(arg); if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) { - id = ((AbstractRayRuntime) Ray.internal()).getObjectStore() + RayRuntime runtime = Ray.internal(); + if (runtime instanceof RayMultiWorkerNativeRuntime) { + runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); + } + id = ((AbstractRayRuntime) runtime).getObjectStore() .put(new NativeRayObject(serialized, null)); } else { data = serialized; diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index c5fd12e92..3da6089d7 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -87,6 +87,7 @@ ray { // See src/ray/ray_config_def.h for options. 
config { + num_workers_per_process_java: 10 } } diff --git a/java/test/src/main/java/org/ray/api/TestUtils.java b/java/test/src/main/java/org/ray/api/TestUtils.java index 3badb1104..a03bd627f 100644 --- a/java/test/src/main/java/org/ray/api/TestUtils.java +++ b/java/test/src/main/java/org/ray/api/TestUtils.java @@ -1,8 +1,11 @@ package org.ray.api; +import com.google.common.base.Preconditions; import java.util.function.Supplier; import org.ray.api.annotation.RayRemote; +import org.ray.api.runtime.RayRuntime; import org.ray.runtime.AbstractRayRuntime; +import org.ray.runtime.RayMultiWorkerNativeRuntime; import org.ray.runtime.config.RunMode; import org.testng.Assert; import org.testng.SkipException; @@ -12,8 +15,7 @@ public class TestUtils { private static final int WAIT_INTERVAL_MS = 5; public static void skipTestUnderSingleProcess() { - AbstractRayRuntime runtime = (AbstractRayRuntime)Ray.internal(); - if (runtime.getRayConfig().runMode == RunMode.SINGLE_PROCESS) { + if (getRuntime().getRayConfig().runMode == RunMode.SINGLE_PROCESS) { throw new SkipException("This test doesn't work under single-process mode."); } } @@ -62,4 +64,13 @@ public class TestUtils { RayObject obj = Ray.call(TestUtils::hi); Assert.assertEquals(obj.get(), "hi"); } + + public static AbstractRayRuntime getRuntime() { + RayRuntime runtime = Ray.internal(); + if (runtime instanceof RayMultiWorkerNativeRuntime) { + runtime = ((RayMultiWorkerNativeRuntime) runtime).getCurrentRuntime(); + } + Preconditions.checkState(runtime instanceof AbstractRayRuntime); + return (AbstractRayRuntime) runtime; + } } diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 0f4e08fa6..a8870558e 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -11,7 +11,6 @@ import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import 
org.ray.api.exception.UnreconstructableException; import org.ray.api.id.UniqueId; -import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.actor.NativeRayActor; import org.ray.runtime.object.NativeRayObject; import org.testng.Assert; @@ -117,7 +116,7 @@ public class ActorTest extends BaseTest { Ray.internal().free(ImmutableList.of(value.getId()), false, false); // Wait until the object is deleted, because the above free operation is async. while (true) { - NativeRayObject result = ((AbstractRayRuntime) Ray.internal()).getObjectStore() + NativeRayObject result = TestUtils.getRuntime().getObjectStore() .getRaw(ImmutableList.of(value.getId()), 0).get(0); if (result == null) { break; diff --git a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java index 939372b96..0603a917e 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java @@ -1,5 +1,6 @@ package org.ray.api.test; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.io.File; @@ -63,7 +64,7 @@ public abstract class BaseMultiLanguageTest { // Start ray cluster. 
String workerOptions = " -classpath " + System.getProperty("java.class.path"); - final List startCommand = ImmutableList.of( + List startCommand = ImmutableList.of( "ray", "start", "--head", @@ -74,6 +75,13 @@ public abstract class BaseMultiLanguageTest { "--include-java", "--java-worker-options=" + workerOptions ); + String numWorkersPerProcessJava = System + .getProperty("ray.raylet.config.num_workers_per_process_java"); + if (!Strings.isNullOrEmpty(numWorkersPerProcessJava)) { + startCommand = ImmutableList.builder().addAll(startCommand) + .add(String.format("--internal-config={\"num_workers_per_process_java\": %s}", + numWorkersPerProcessJava)).build(); + } if (!executeCommand(startCommand, 10, getRayStartEnv())) { throw new RuntimeException("Couldn't start ray cluster."); } diff --git a/java/test/src/main/java/org/ray/api/test/FailureTest.java b/java/test/src/main/java/org/ray/api/test/FailureTest.java index b47b010ae..4dbcbcf4b 100644 --- a/java/test/src/main/java/org/ray/api/test/FailureTest.java +++ b/java/test/src/main/java/org/ray/api/test/FailureTest.java @@ -127,6 +127,7 @@ public class FailureTest extends BaseTest { TestUtils.skipTestUnderSingleProcess(); List> badFunctions = Arrays.asList(FailureTest::badFunc, FailureTest::badFunc2); + TestUtils.warmUpCluster(); for (RayFunc0 badFunc : badFunctions) { RayObject obj1 = Ray.call(badFunc); RayObject obj2 = Ray.call(FailureTest::slowFunc); diff --git a/java/test/src/main/java/org/ray/api/test/GcsClientTest.java b/java/test/src/main/java/org/ray/api/test/GcsClientTest.java index 04b08b64b..bd778f91b 100644 --- a/java/test/src/main/java/org/ray/api/test/GcsClientTest.java +++ b/java/test/src/main/java/org/ray/api/test/GcsClientTest.java @@ -2,11 +2,9 @@ package org.ray.api.test; import com.google.common.base.Preconditions; import java.util.List; -import org.ray.api.Ray; import org.ray.api.TestUtils; import org.ray.api.id.JobId; import org.ray.api.runtimecontext.NodeInfo; -import 
org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RayConfig; import org.ray.runtime.gcs.GcsClient; import org.testng.Assert; @@ -29,10 +27,10 @@ public class GcsClientTest extends BaseTest { @Test public void testGetAllNodeInfo() { TestUtils.skipTestUnderSingleProcess(); - RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig(); + RayConfig config = TestUtils.getRuntime().getRayConfig(); Preconditions.checkNotNull(config); - GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient(); + GcsClient gcsClient = TestUtils.getRuntime().getGcsClient(); List allNodeInfo = gcsClient.getAllNodeInfo(); Assert.assertEquals(allNodeInfo.size(), 1); Assert.assertEquals(allNodeInfo.get(0).nodeAddress, config.nodeIp); @@ -43,11 +41,11 @@ public class GcsClientTest extends BaseTest { @Test public void testNextJob() { TestUtils.skipTestUnderSingleProcess(); - RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig(); + RayConfig config = TestUtils.getRuntime().getRayConfig(); // The value of job id of this driver in cluster should be `1L`. 
Assert.assertEquals(config.getJobId(), JobId.fromInt(1)); - GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient(); + GcsClient gcsClient = TestUtils.getRuntime().getGcsClient(); for (int i = 2; i < 100; ++i) { Assert.assertEquals(gcsClient.nextJobId(), JobId.fromInt(i)); } diff --git a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java index 86fdc1b94..ce2fb2452 100644 --- a/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java +++ b/java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java @@ -54,19 +54,22 @@ public class MultiThreadingTest extends BaseTest { } @RayRemote - public ActorId getCurrentActorId() { - final ActorId[] result = new ActorId[1]; - Thread thread = new Thread(() -> { - result[0] = Ray.getRuntimeContext().getCurrentActorId(); - }); + public ActorId getCurrentActorId() throws Exception { + final Object[] result = new Object[1]; + Thread thread = new Thread(Ray.wrapRunnable(() -> { + try { + result[0] = Ray.getRuntimeContext().getCurrentActorId(); + } catch (Exception e) { + result[0] = e; + } + })); thread.start(); - try { - thread.join(); - } catch (InterruptedException e) { - throw new RuntimeException(e); + thread.join(); + if (result[0] instanceof Exception) { + throw (Exception) result[0]; } Assert.assertEquals(result[0], actorId); - return result[0]; + return (ActorId) result[0]; } } @@ -147,13 +150,13 @@ public class MultiThreadingTest extends BaseTest { try { List> futures = new ArrayList<>(); for (int i = 0; i < NUM_THREADS; i++) { - Callable task = () -> { + Callable task = Ray.wrapCallable(() -> { for (int j = 0; j < numRepeats; j++) { TimeUnit.MILLISECONDS.sleep(1); testCase.run(); } return "ok"; - }; + }); futures.add(service.submit(task)); } for (Future future : futures) { diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 
e0d18d9a5..395209ad5 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -7,7 +7,6 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.api.id.TaskId; -import org.ray.runtime.AbstractRayRuntime; import org.testng.Assert; import org.testng.annotations.Test; @@ -26,7 +25,7 @@ public class PlasmaFreeTest extends BaseTest { Ray.internal().free(ImmutableList.of(helloId.getId()), true, false); final boolean result = TestUtils.waitForCondition(() -> - ((AbstractRayRuntime) Ray.internal()).getObjectStore() + TestUtils.getRuntime().getObjectStore() .getRaw(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50); Assert.assertTrue(result); } @@ -40,7 +39,7 @@ public class PlasmaFreeTest extends BaseTest { TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH)); final boolean result = TestUtils.waitForCondition( - () -> !(((AbstractRayRuntime) Ray.internal()).getGcsClient()) + () -> !TestUtils.getRuntime().getGcsClient() .rayletTaskExistsInGcs(taskId), 50); Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java index fc0a06e91..e0655e052 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaStoreTest.java @@ -1,11 +1,8 @@ package org.ray.api.test; -import java.util.Collections; import org.ray.api.Ray; import org.ray.api.TestUtils; import org.ray.api.id.ObjectId; -import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectStore; import org.testng.Assert; import org.testng.annotations.Test; @@ -16,8 +13,7 @@ public class PlasmaStoreTest extends BaseTest { public void testPutWithDuplicateId() { TestUtils.skipTestUnderSingleProcess(); 
ObjectId objectId = ObjectId.fromRandom(); - AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal(); - ObjectStore objectStore = runtime.getObjectStore(); + ObjectStore objectStore = TestUtils.getRuntime().getObjectStore(); objectStore.put("1", objectId); Assert.assertEquals(Ray.get(objectId), "1"); objectStore.put("2", objectId); diff --git a/java/test/src/main/java/org/ray/api/test/RayCallTest.java b/java/test/src/main/java/org/ray/api/test/RayCallTest.java index 37def8733..db496e2b8 100644 --- a/java/test/src/main/java/org/ray/api/test/RayCallTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayCallTest.java @@ -6,9 +6,9 @@ import java.io.Serializable; import java.util.List; import java.util.Map; import org.ray.api.Ray; +import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.api.id.ObjectId; -import org.ray.runtime.AbstractRayRuntime; import org.testng.Assert; import org.testng.annotations.Test; @@ -80,7 +80,7 @@ public class RayCallTest extends BaseTest { @RayRemote private static void testNoReturn(ObjectId objectId) { // Put an object in object store to inform driver that this function is executing. 
- ((AbstractRayRuntime) Ray.internal()).getObjectStore().put(1, objectId); + TestUtils.getRuntime().getObjectStore().put(1, objectId); } /** diff --git a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java index 729815c23..0cb8c6b23 100644 --- a/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java +++ b/java/test/src/main/java/org/ray/api/test/RaySerializerTest.java @@ -2,8 +2,8 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayPyActor; +import org.ray.api.TestUtils; import org.ray.api.id.ObjectId; -import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.object.NativeRayObject; import org.ray.runtime.object.ObjectStore; import org.testng.Assert; @@ -14,7 +14,7 @@ public class RaySerializerTest extends BaseMultiLanguageTest { @Test public void testSerializePyActor() { RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest"); - ObjectStore objectStore = ((AbstractRayRuntime) Ray.internal()).getObjectStore(); + ObjectStore objectStore = TestUtils.getRuntime().getObjectStore(); NativeRayObject nativeRayObject = objectStore.serialize(pyActor); RayPyActor result = (RayPyActor) objectStore .deserialize(nativeRayObject, ObjectId.fromRandom()); diff --git a/python/ray/includes/ray_config.pxd b/python/ray/includes/ray_config.pxd index 41adec160..4d3a76ec7 100644 --- a/python/ray/includes/ray_config.pxd +++ b/python/ray/includes/ray_config.pxd @@ -76,7 +76,9 @@ cdef extern from "ray/common/ray_config.h" nogil: uint64_t object_manager_default_chunk_size() const - int num_workers_per_process() const + int num_workers_per_process_python() const + + int num_workers_per_process_java() const int64_t max_task_lease_timeout_ms() const diff --git a/python/ray/includes/ray_config.pxi b/python/ray/includes/ray_config.pxi index 03f9a8df5..7171884ee 100644 --- a/python/ray/includes/ray_config.pxi +++ b/python/ray/includes/ray_config.pxi @@ -142,8 +142,12 
@@ cdef class Config: return RayConfig.instance().object_manager_default_chunk_size() @staticmethod - def num_workers_per_process(): - return RayConfig.instance().num_workers_per_process() + def num_workers_per_process_python(): + return RayConfig.instance().num_workers_per_process_python() + + @staticmethod + def num_workers_per_process_java(): + return RayConfig.instance().num_workers_per_process_java() @staticmethod def max_task_lease_timeout_ms(): diff --git a/python/ray/services.py b/python/ray/services.py index b15fcbb84..5b806f8dc 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1181,12 +1181,15 @@ def build_java_worker_command( command += "-Dray.home={} ".format(RAY_HOME) command += "-Dray.log-dir={} ".format(os.path.join(session_dir, "logs")) + command += ("-Dray.raylet.config.num_workers_per_process_java=" + + "RAY_WORKER_NUM_WORKERS_PLACEHOLDER ") + if java_worker_options: # Put `java_worker_options` in the last, so it can overwrite the # above options. command += java_worker_options + " " - command += "RAY_WORKER_OPTION_0 " + command += "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0 " command += "org.ray.runtime.runner.worker.DefaultWorker" return command diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 5a82fed2d..e117d4bac 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -21,6 +21,9 @@ constexpr char kObjectTablePrefix[] = "ObjectTable"; /// Prefix for the task table keys in redis. 
constexpr char kTaskTablePrefix[] = "TaskTable"; -constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_OPTION_"; +constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = + "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_"; + +constexpr char kWorkerNumWorkersPlaceholder[] = "RAY_WORKER_NUM_WORKERS_PLACEHOLDER"; #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/common/ray_config_def.h b/src/ray/common/ray_config_def.h index d0a6000c6..e0720006a 100644 --- a/src/ray/common/ray_config_def.h +++ b/src/ray/common/ray_config_def.h @@ -137,8 +137,11 @@ RAY_CONFIG(int, object_manager_repeated_push_delay_ms, 60000) /// chunks exceeds the number of available sending threads. RAY_CONFIG(uint64_t, object_manager_default_chunk_size, 1000000) -/// Number of workers per process -RAY_CONFIG(int, num_workers_per_process, 1) +/// Number of workers per Python worker process +RAY_CONFIG(int, num_workers_per_process_python, 1) + +/// Number of workers per Java worker process +RAY_CONFIG(int, num_workers_per_process_java, 10) /// Maximum timeout in milliseconds within which a task lease must be renewed. 
RAY_CONFIG(int64_t, max_task_lease_timeout_ms, 60000) diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index ca219ddae..beaa000b2 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -84,10 +84,12 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, jobject java_resources = env->GetObjectField(actorCreationOptions, java_base_task_options_resources); resources = ToResources(env, java_resources); - std::string jvm_options = JavaStringToNativeString( - env, (jstring)env->GetObjectField(actorCreationOptions, - java_actor_creation_options_jvm_options)); - dynamic_worker_options.emplace_back(jvm_options); + jstring java_jvm_options = (jstring)env->GetObjectField( + actorCreationOptions, java_actor_creation_options_jvm_options); + if (java_jvm_options) { + std::string jvm_options = JavaStringToNativeString(env, java_jvm_options); + dynamic_worker_options.emplace_back(jvm_options); + } } ray::ActorCreationOptions action_creation_options{ diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index badcbea5b..15f17c2f5 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -85,8 +85,8 @@ message ActorCreationTaskSpec { uint64 max_actor_reconstructions = 3; // The dynamic options used in the worker command when starting a worker process for // an actor creation task. If the list isn't empty, the options will be used to replace - // the placeholder strings (`RAY_WORKER_OPTION_0`, `RAY_WORKER_OPTION_1`, etc) in the - // worker command. + // the placeholder strings (`RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0`, + // `RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1`, etc) in the worker command. 
repeated string dynamic_worker_options = 4; } diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index af68db41d..9a67ecfa8 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -104,8 +104,6 @@ int main(int argc, char *argv[]) { node_manager_config.node_manager_address = node_ip_address; node_manager_config.node_manager_port = node_manager_port; node_manager_config.num_initial_workers = num_initial_workers; - node_manager_config.num_workers_per_process = - RayConfig::instance().num_workers_per_process(); node_manager_config.maximum_startup_concurrency = maximum_startup_concurrency; if (!python_worker_command.empty()) { diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d74be41c6..fa113cce8 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -82,9 +82,8 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, object_manager_profile_timer_(io_service), initial_config_(config), local_available_resources_(config.resource_config), - worker_pool_(config.num_initial_workers, config.num_workers_per_process, - config.maximum_startup_concurrency, gcs_client_, - config.worker_commands), + worker_pool_(config.num_initial_workers, config.maximum_startup_concurrency, + gcs_client_, config.worker_commands), scheduling_policy_(local_queues_), reconstruction_policy_( io_service_, diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 931853134..2afd90460 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -43,8 +43,6 @@ struct NodeManagerConfig { int node_manager_port; /// The initial number of workers to create. int num_initial_workers; - /// The number of workers per process. - int num_workers_per_process; /// The maximum number of workers that can be started concurrently by a /// worker pool. 
int maximum_startup_concurrency; diff --git a/src/ray/raylet/object_manager_integration_test.cc b/src/ray/raylet/object_manager_integration_test.cc index 73525f2e4..128a07e90 100644 --- a/src/ray/raylet/object_manager_integration_test.cc +++ b/src/ray/raylet/object_manager_integration_test.cc @@ -39,7 +39,6 @@ class TestObjectManagerBase : public ::testing::Test { node_manager_config.resource_config = ray::raylet::ResourceSet(std::move(static_resource_conf)); node_manager_config.num_initial_workers = 0; - node_manager_config.num_workers_per_process = 1; // Use a default worker that can execute empty tasks with dependencies. std::vector py_worker_command; py_worker_command.push_back("python"); diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index f2e77c6cf..ade829870 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -40,17 +40,13 @@ namespace ray { namespace raylet { /// A constructor that initializes a worker pool with -/// (num_worker_processes * num_workers_per_process) workers for each language. -WorkerPool::WorkerPool(int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, +/// (num_worker_processes * states_by_lang_[language].num_workers_per_process) workers for +/// each language. +WorkerPool::WorkerPool(int num_worker_processes, int maximum_startup_concurrency, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands) - : num_workers_per_process_(num_workers_per_process), - multiple_for_warning_(std::max(num_worker_processes, maximum_startup_concurrency)), - maximum_startup_concurrency_(maximum_startup_concurrency), - last_warning_multiple_(0), + : maximum_startup_concurrency_(maximum_startup_concurrency), gcs_client_(std::move(gcs_client)) { - RAY_CHECK(num_workers_per_process > 0) << "num_workers_per_process must be positive."; RAY_CHECK(maximum_startup_concurrency > 0); // Ignore SIGCHLD signals. 
If we don't do this, then worker processes will // become zombies instead of dying gracefully. @@ -58,6 +54,25 @@ WorkerPool::WorkerPool(int num_worker_processes, int num_workers_per_process, for (const auto &entry : worker_commands) { // Initialize the pool state for this language. auto &state = states_by_lang_[entry.first]; + switch (entry.first) { + case Language::PYTHON: + state.num_workers_per_process = + RayConfig::instance().num_workers_per_process_python(); + break; + case Language::JAVA: + state.num_workers_per_process = + RayConfig::instance().num_workers_per_process_java(); + break; + default: + RAY_LOG(FATAL) << "The number of workers per process for " + << Language_Name(entry.first) << " worker is not set."; + } + RAY_CHECK(state.num_workers_per_process > 0) + << "Number of workers per process of language " << Language_Name(entry.first) + << " must be positive."; + state.multiple_for_warning = + std::max(num_worker_processes, maximum_startup_concurrency) * + state.num_workers_per_process; // Set worker command for this language. state.worker_command = entry.second; RAY_CHECK(!state.worker_command.empty()) << "Worker command must not be empty."; @@ -119,9 +134,17 @@ int WorkerPool::StartWorkerProcess(const Language &language, << state.idle_actor.size() << " actor workers, and " << state.idle.size() << " non-actor workers"; + int workers_to_start; + if (dynamic_options.empty()) { + workers_to_start = state.num_workers_per_process; + } else { + workers_to_start = 1; + } + // Extract pointers from the worker command to pass into execvp. 
std::vector worker_command_args; size_t dynamic_option_index = 0; + bool num_workers_arg_replaced = false; for (auto const &token : state.worker_command) { const auto option_placeholder = kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); @@ -135,9 +158,22 @@ int WorkerPool::StartWorkerProcess(const Language &language, ++dynamic_option_index; } } else { - worker_command_args.push_back(token); + size_t num_workers_index = token.find(kWorkerNumWorkersPlaceholder); + if (num_workers_index != std::string::npos) { + std::string arg = token; + worker_command_args.push_back(arg.replace(num_workers_index, + strlen(kWorkerNumWorkersPlaceholder), + std::to_string(workers_to_start))); + num_workers_arg_replaced = true; + } else { + worker_command_args.push_back(token); + } } } + RAY_CHECK(num_workers_arg_replaced || state.num_workers_per_process == 1) + << "Expect to start " << state.num_workers_per_process << " workers per " + << Language_Name(language) << " worker process. But the " + << kWorkerNumWorkersPlaceholder << "placeholder is not found in worker command."; pid_t pid = StartProcess(worker_command_args); if (pid < 0) { @@ -145,9 +181,9 @@ int WorkerPool::StartWorkerProcess(const Language &language, RAY_LOG(FATAL) << "Failed to fork worker process: " << strerror(errno); } else if (pid > 0) { // Parent process case. 
- RAY_LOG(DEBUG) << "Started worker process with pid " << pid; - state.starting_worker_processes.emplace( - std::make_pair(pid, num_workers_per_process_)); + RAY_LOG(DEBUG) << "Started worker process of " << workers_to_start + << " worker(s) with pid " << pid; + state.starting_worker_processes.emplace(pid, workers_to_start); return pid; } return -1; @@ -355,27 +391,30 @@ std::vector> WorkerPool::GetWorkersRunningTasksForJob( } void WorkerPool::WarnAboutSize() { - int64_t num_workers_started_or_registered = 0; for (const auto &entry : states_by_lang_) { + auto state = entry.second; + int64_t num_workers_started_or_registered = 0; num_workers_started_or_registered += - static_cast(entry.second.registered_workers.size()); - num_workers_started_or_registered += - static_cast(entry.second.starting_worker_processes.size()); - } - int64_t multiple = num_workers_started_or_registered / multiple_for_warning_; - std::stringstream warning_message; - if (multiple >= 3 && multiple > last_warning_multiple_) { - // Push an error message to the user if the worker pool tells us that it is - // getting too big. - last_warning_multiple_ = multiple; - warning_message << "WARNING: " << num_workers_started_or_registered - << " workers have been started. 
This could be a result of using " - << "a large number of actors, or it could be a consequence of " - << "using nested tasks " - << "(see https://github.com/ray-project/ray/issues/3644) for " - << "some a discussion of workarounds."; - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - JobID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); + static_cast(state.registered_workers.size()); + for (const auto &starting_process : state.starting_worker_processes) { + num_workers_started_or_registered += starting_process.second; + } + int64_t multiple = num_workers_started_or_registered / state.multiple_for_warning; + std::stringstream warning_message; + if (multiple >= 3 && multiple > state.last_warning_multiple) { + // Push an error message to the user if the worker pool tells us that it is + // getting too big. + state.last_warning_multiple = multiple; + warning_message << "WARNING: " << num_workers_started_or_registered << " " + << Language_Name(entry.first) + << " workers have been started. 
This could be a result of using " + << "a large number of actors, or it could be a consequence of " + << "using nested tasks " + << "(see https://github.com/ray-project/ray/issues/3644) for " + << "some a discussion of workarounds."; + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( + JobID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); + } } } @@ -390,8 +429,10 @@ std::string WorkerPool::DebugString() const { std::stringstream result; result << "WorkerPool:"; for (const auto &entry : states_by_lang_) { - result << "\n- num workers: " << entry.second.registered_workers.size(); - result << "\n- num drivers: " << entry.second.registered_drivers.size(); + result << "\n- num " << Language_Name(entry.first) + << " workers: " << entry.second.registered_workers.size(); + result << "\n- num " << Language_Name(entry.first) + << " drivers: " << entry.second.registered_drivers.size(); } return result.str(); } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 4d6f4b307..fe89140f1 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -33,14 +33,12 @@ class WorkerPool { /// and add them to the pool. /// /// \param num_worker_processes The number of worker processes to start, per language. - /// \param num_workers_per_process The number of workers per process. /// \param maximum_startup_concurrency The maximum number of worker processes /// that can be started in parallel (typically this should be set to the number of CPU /// resources on the machine). /// \param worker_commands The commands used to start the worker process, grouped by /// language. - WorkerPool(int num_worker_processes, int num_workers_per_process, - int maximum_startup_concurrency, + WorkerPool(int num_worker_processes, int maximum_startup_concurrency, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands); @@ -134,7 +132,7 @@ class WorkerPool { protected: /// Asynchronously start a new worker process. 
Once the worker process has /// registered with an external server, the process should create and - /// register num_workers_per_process_ workers, then add them to the pool. + /// register num_workers_per_process workers, then add them to the pool. /// Failure to start the worker process is a fatal error. If too many workers /// are already being started, then this function will return without starting /// any workers. @@ -159,6 +157,8 @@ class WorkerPool { struct State { /// The commands and arguments used to start the worker process std::vector worker_command; + /// The number of workers per process. + int num_workers_per_process; /// The pool of dedicated workers for actor creation tasks /// with prefix or suffix worker command. std::unordered_map> idle_dedicated_workers; @@ -179,10 +179,14 @@ class WorkerPool { std::unordered_map dedicated_workers_to_tasks; /// A map for speeding up looking up the pending worker for the given task. std::unordered_map tasks_to_dedicated_workers; + /// We'll push a warning to the user every time a multiple of this many + /// worker processes has been started. + int multiple_for_warning; + /// The last size at which a warning about the number of registered workers + /// was generated. + int64_t last_warning_multiple; }; - /// The number of workers per process. - int num_workers_per_process_; /// Pool states per language. std::unordered_map> states_by_lang_; @@ -191,14 +195,8 @@ class WorkerPool { /// for a given language. State &GetStateForLanguage(const Language &language); - /// We'll push a warning to the user every time a multiple of this many - /// workers has been started. - int multiple_for_warning_; - /// The maximum number of workers that can be started concurrently. + /// The maximum number of worker processes that can be started concurrently. int maximum_startup_concurrency_; - /// The last size at which a warning about the number of registered workers - /// was generated. 
- int64_t last_warning_multiple_; /// A client connection to the GCS. std::shared_ptr gcs_client_; }; diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 34ec1619a..c3f8702b9 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -15,13 +15,20 @@ int MAXIMUM_STARTUP_CONCURRENCY = 5; class WorkerPoolMock : public WorkerPool { public: WorkerPoolMock() - : WorkerPoolMock({{Language::PYTHON, {"dummy_py_worker_command"}}, - {Language::JAVA, {"dummy_java_worker_command"}}}) {} + : WorkerPoolMock( + {{Language::PYTHON, + {"dummy_py_worker_command", "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"}}, + {Language::JAVA, + {"dummy_java_worker_command", + "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"}}}) {} explicit WorkerPoolMock(const WorkerCommandMap &worker_commands) - : WorkerPool(0, NUM_WORKERS_PER_PROCESS, MAXIMUM_STARTUP_CONCURRENCY, nullptr, - worker_commands), - last_worker_pid_(0) {} + : WorkerPool(0, MAXIMUM_STARTUP_CONCURRENCY, nullptr, worker_commands), + last_worker_pid_(0) { + for (auto &entry : states_by_lang_) { + entry.second.num_workers_per_process = NUM_WORKERS_PER_PROCESS; + } + } ~WorkerPoolMock() { // Avoid killing real processes @@ -150,16 +157,38 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { } TEST_F(WorkerPoolTest, StartupWorkerCount) { - int desired_initial_worker_count_per_language = 20; - for (int i = 0; i < desired_initial_worker_count_per_language; i++) { - worker_pool_.StartWorkerProcess(Language::PYTHON); - worker_pool_.StartWorkerProcess(Language::JAVA); + std::string num_workers_arg = + std::string("--foo=") + std::to_string(NUM_WORKERS_PER_PROCESS); + std::vector languages = {Language::PYTHON, Language::JAVA}; + std::vector> worker_commands = { + {{"dummy_py_worker_command", num_workers_arg}, + {"dummy_java_worker_command", num_workers_arg}}}; + int desired_initial_worker_process_count_per_language = MAXIMUM_STARTUP_CONCURRENCY + 1; + int 
expected_worker_process_count = MAXIMUM_STARTUP_CONCURRENCY * languages.size(); + pid_t last_started_worker_process = 0; + for (int i = 0; i < desired_initial_worker_process_count_per_language; i++) { + for (size_t j = 0; j < languages.size(); j++) { + worker_pool_.StartWorkerProcess(languages[j]); + ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() <= + expected_worker_process_count); + if (last_started_worker_process != worker_pool_.LastStartedWorkerProcess()) { + last_started_worker_process = worker_pool_.LastStartedWorkerProcess(); + const auto &real_command = + worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); + ASSERT_EQ(real_command, worker_commands[j]); + } else { + ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() == + expected_worker_process_count); + ASSERT_TRUE(static_cast(i * languages.size() + j) >= + expected_worker_process_count); + } + } } - // Check that number of starting worker processes equals to - // maximum_startup_concurrency_ * 2. (because we started both python and java workers) - ASSERT_EQ( - worker_pool_.NumWorkerProcessesStarting(), - /* Provided in constructor of WorkerPoolMock */ MAXIMUM_STARTUP_CONCURRENCY * 2); + // Check number of starting worker processes + ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), expected_worker_process_count); + ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() < + static_cast(desired_initial_worker_process_count_per_language * + languages.size())); } TEST_F(WorkerPoolTest, HandleWorkerPushPop) { @@ -232,7 +261,9 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { const std::vector java_worker_command = { - "RAY_WORKER_OPTION_0", "dummy_java_worker_command", "RAY_WORKER_OPTION_1"}; + "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0", "dummy_java_worker_command", + "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER", + "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1"}; SetWorkerCommands({{Language::PYTHON, 
{"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); @@ -243,8 +274,9 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { worker_pool_.StartWorkerProcess(Language::JAVA, task_spec.DynamicWorkerOptions()); const auto real_command = worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); - ASSERT_EQ(real_command, std::vector( - {"test_op_0", "dummy_java_worker_command", "test_op_1"})); + ASSERT_EQ(real_command, + std::vector( + {"test_op_0", "dummy_java_worker_command", "--foo=1", "test_op_1"})); } } // namespace raylet