From 630e48967d2400580c879600534b77e4af5e27ff Mon Sep 17 00:00:00 2001 From: Kai Yang Date: Sun, 15 Mar 2020 12:03:38 +0800 Subject: [PATCH] [Java] Allow passing internal config from raylet to Java worker (#7532) --- .../org/ray/runtime/RayNativeRuntime.java | 4 +- .../org/ray/runtime/config/RayConfig.java | 10 +- .../org/ray/runtime/runner/RunManager.java | 13 ++- .../ray/api/test/BaseMultiLanguageTest.java | 12 +- .../org/ray/api/test/RayletConfigTest.java | 39 +++++++ python/ray/services.py | 4 +- src/ray/common/constants.h | 2 +- src/ray/core_worker/lib/java/jni_utils.h | 33 ++++++ .../java/org_ray_runtime_RayNativeRuntime.cc | 14 ++- .../java/org_ray_runtime_RayNativeRuntime.h | 5 +- ...rg_ray_runtime_task_NativeTaskSubmitter.cc | 31 ++--- src/ray/raylet/main.cc | 1 + src/ray/raylet/node_manager.cc | 2 +- src/ray/raylet/node_manager.h | 2 + src/ray/raylet/worker_pool.cc | 59 +++++++--- src/ray/raylet/worker_pool.h | 4 + src/ray/raylet/worker_pool_test.cc | 110 +++++++++--------- 17 files changed, 224 insertions(+), 121 deletions(-) create mode 100644 java/test/src/main/java/org/ray/api/test/RayletConfigTest.java diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 1255b31f1..5f69f6a4d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -61,7 +61,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } catch (IOException e) { throw new RuntimeException("Failed to create the log directory.", e); } - nativeSetup(rayConfig.logDir); + nativeSetup(rayConfig.logDir, rayConfig.rayletConfigParameters); Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook)); } @@ -193,7 +193,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer); - private static native void nativeSetup(String logDir); + private static native void nativeSetup(String logDir, Map rayletConfigParameters); private static native void nativeShutdownHook(); diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index d6e76d1eb..1ba6246c7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -10,7 +10,7 @@ import com.typesafe.config.ConfigValue; import java.io.File; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; @@ -67,7 +67,7 @@ public class RayConfig { public String rayletSocketName; private int nodeManagerPort; - public final List rayletConfigParameters; + public final Map rayletConfigParameters; public final String jobResourcePath; public final String pythonWorkerCommand; @@ -204,11 +204,11 @@ public class RayConfig { } // Raylet parameters. - rayletConfigParameters = new ArrayList<>(); + rayletConfigParameters = new HashMap<>(); Config rayletConfig = config.getConfig("ray.raylet.config"); for (Map.Entry entry : rayletConfig.entrySet()) { - String parameter = entry.getKey() + "," + entry.getValue().unwrapped(); - rayletConfigParameters.add(parameter); + Object value = entry.getValue().unwrapped(); + rayletConfigParameters.put(entry.getKey(), value == null ? "" : value.toString()); } // Job resource path. diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 514b0bb59..13851f119 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -240,7 +240,10 @@ public class RunManager { gcsServerFile.getAbsolutePath(), String.format("--redis_address=%s", rayConfig.getRedisIp()), String.format("--redis_port=%d", rayConfig.getRedisPort()), - String.format("--config_list=%s", String.join(",", rayConfig.rayletConfigParameters)), + String.format("--config_list=%s", + rayConfig.rayletConfigParameters.entrySet().stream() + .map(entry -> entry.getKey() + "," + entry.getValue()).collect(Collectors + .joining(","))), String.format("--redis_password=%s", redisPasswordOption) ); startProcess(command, null, "gcs_server"); @@ -316,7 +319,9 @@ public class RunManager { String.format("--maximum_startup_concurrency=%d", maximumStartupConcurrency), String.format("--static_resource_list=%s", ResourceUtil.getResourcesStringFromMap(rayConfig.resources)), - String.format("--config_list=%s", String.join(",", rayConfig.rayletConfigParameters)), + String.format("--config_list=%s", rayConfig.rayletConfigParameters.entrySet().stream() + .map(entry -> entry.getKey() + "," + entry.getValue()) + .collect(Collectors.joining(","))), String.format("--python_worker_command=%s", buildPythonWorkerCommand()), String.format("--java_worker_command=%s", buildWorkerCommand()), String.format("--redis_password=%s", redisPasswordOption) @@ -378,8 +383,8 @@ public class RunManager { cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword); } - // Number of workers per Java worker process - cmd.add("-Dray.raylet.config.num_workers_per_process_java=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"); + + cmd.add("RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"); cmd.addAll(rayConfig.jvmParameters); diff --git a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java index 40c5cd5ab..015b5af4e 100644 --- a/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java +++ b/java/test/src/main/java/org/ray/api/test/BaseMultiLanguageTest.java @@ -1,6 +1,5 @@ package org.ray.api.test; -import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.gson.Gson; @@ -13,6 +12,7 @@ import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.Stream; import org.ray.api.Ray; +import org.ray.runtime.config.RayConfig; import org.ray.runtime.util.NetworkUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,15 +90,9 @@ public abstract class BaseMultiLanguageTest { String.format("--node-manager-port=%s", nodeManagerPort), "--load-code-from-local", "--include-java", - "--java-worker-options=" + workerOptions + "--java-worker-options=" + workerOptions, + "--internal-config=" + new Gson().toJson(RayConfig.create().rayletConfigParameters) ); - String numWorkersPerProcessJava = System - .getProperty("ray.raylet.config.num_workers_per_process_java"); - if (!Strings.isNullOrEmpty(numWorkersPerProcessJava)) { - startCommand = ImmutableList.builder().addAll(startCommand) - .add(String.format("--internal-config={\"num_workers_per_process_java\": %s}", - numWorkersPerProcessJava)).build(); - } if (!executeCommand(startCommand, 10, getRayStartEnv())) { throw new RuntimeException("Couldn't start ray cluster."); } diff --git a/java/test/src/main/java/org/ray/api/test/RayletConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayletConfigTest.java new file mode 100644 index 000000000..932d443b0 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/RayletConfigTest.java @@ -0,0 +1,39 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.TestUtils; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class RayletConfigTest extends BaseTest { + + private static final String RAY_CONFIG_KEY = "num_workers_per_process_java"; + private static final String RAY_CONFIG_VALUE = "2"; + + @BeforeClass + public void beforeClass() { + System.setProperty("ray.raylet.config." + RAY_CONFIG_KEY, RAY_CONFIG_VALUE); + } + + @AfterClass + public void afterClass() { + System.clearProperty("ray.raylet.config." + RAY_CONFIG_KEY); + } + + public static class TestActor { + + public String getConfigValue() { + return TestUtils.getRuntime().getRayConfig().rayletConfigParameters.get(RAY_CONFIG_KEY); + } + } + + @Test + public void testRayletConfigPassThrough() { + RayActor actor = Ray.createActor(TestActor::new); + String configValue = actor.call(TestActor::getConfigValue).get(); + Assert.assertEquals(configValue, RAY_CONFIG_VALUE); + } +} diff --git a/python/ray/services.py b/python/ray/services.py index a03ea4baa..95ee2c484 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -1366,11 +1366,11 @@ def build_java_worker_command( pairs.append(("ray.home", RAY_HOME)) pairs.append(("ray.log-dir", os.path.join(session_dir, "logs"))) pairs.append(("ray.session-dir", session_dir)) - pairs.append(("ray.raylet.config.num_workers_per_process_java", - "RAY_WORKER_NUM_WORKERS_PLACEHOLDER")) command = ["java"] + ["-D{}={}".format(*pair) for pair in pairs] + command += ["RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"] + # Add ray jars path to java classpath ray_jars = os.path.join(get_ray_jars_dir(), "*") if java_worker_options is None: diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index fd1004b84..f28cd8299 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -41,6 +41,6 @@ constexpr char kTaskTablePrefix[] = "TaskTable"; constexpr char kWorkerDynamicOptionPlaceholderPrefix[] = "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_"; -constexpr char kWorkerNumWorkersPlaceholder[] = "RAY_WORKER_NUM_WORKERS_PLACEHOLDER"; +constexpr char kWorkerRayletConfigPlaceholder[] = "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"; #endif // RAY_CONSTANTS_H_ diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 02d733019..d3a318b74 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -300,6 +300,39 @@ inline jobject NativeIdVectorToJavaByteArrayList(JNIEnv *env, }); } +/// Convert a Java Map to a C++ std::unordered_map +template +inline std::unordered_map JavaMapToNativeMap( + JNIEnv *env, jobject java_map, + const std::function &key_converter, + const std::function &value_converter) { + std::unordered_map native_map; + if (java_map) { + jobject entry_set = env->CallObjectMethod(java_map, java_map_entry_set); + RAY_CHECK_JAVA_EXCEPTION(env); + jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator); + RAY_CHECK_JAVA_EXCEPTION(env); + while (env->CallBooleanMethod(iterator, java_iterator_has_next)) { + RAY_CHECK_JAVA_EXCEPTION(env); + jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next); + RAY_CHECK_JAVA_EXCEPTION(env); + auto java_key = (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key); + RAY_CHECK_JAVA_EXCEPTION(env); + key_type key = key_converter(env, java_key); + auto java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value); + value_type value = value_converter(env, java_value); + native_map.emplace(key, value); + env->DeleteLocalRef(java_key); + env->DeleteLocalRef(java_value); + env->DeleteLocalRef(map_entry); + } + RAY_CHECK_JAVA_EXCEPTION(env); + env->DeleteLocalRef(iterator); + env->DeleteLocalRef(entry_set); + } + return native_map; +} + /// Convert a C++ ray::Buffer to a Java byte array. inline jbyteArray NativeBufferToJavaByteArray(JNIEnv *env, const std::shared_ptr buffer) { diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index 555d8e522..8679049b3 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -128,15 +128,23 @@ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWo delete core_worker; } -JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *env, - jclass, - jstring logDir) { +JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup( + JNIEnv *env, jclass, jstring logDir, jobject rayletConfigParameters) { std::string log_dir = JavaStringToNativeString(env, logDir); ray::RayLog::StartRayLog("java_worker", ray::RayLogLevel::INFO, log_dir); // TODO (kfstorm): We can't InstallFailureSignalHandler here, because JVM already // installed its own signal handler. It's possible to fix this by chaining signal // handlers. But it's not easy. See // https://docs.oracle.com/javase/9/troubleshoot/handle-signals-and-exceptions.htm. + auto raylet_config = JavaMapToNativeMap( + env, rayletConfigParameters, + [](JNIEnv *env, jobject java_key) { + return JavaStringToNativeString(env, (jstring)java_key); + }, + [](JNIEnv *env, jobject java_value) { + return JavaStringToNativeString(env, (jstring)java_value); + }); + RayConfig::instance().initialize(raylet_config); } JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeShutdownHook(JNIEnv *, diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h index 821a1fadc..e6dadede5 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.h @@ -49,10 +49,11 @@ Java_org_ray_runtime_RayNativeRuntime_nativeDestroyCoreWorker(JNIEnv *, jclass, /* * Class: org_ray_runtime_RayNativeRuntime * Method: nativeSetup - * Signature: (Ljava/lang/String;)V + * Signature: (Ljava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeSetup(JNIEnv *, jclass, - jstring); + jstring, + jobject); /* * Class: org_ray_runtime_RayNativeRuntime diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index 6c248d082..f03ca920c 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -65,27 +65,16 @@ inline std::vector ToTaskArgs(JNIEnv *env, jobject args) { inline std::unordered_map ToResources(JNIEnv *env, jobject java_resources) { std::unordered_map resources; - if (java_resources) { - jobject entry_set = env->CallObjectMethod(java_resources, java_map_entry_set); - RAY_CHECK_JAVA_EXCEPTION(env); - jobject iterator = env->CallObjectMethod(entry_set, java_set_iterator); - RAY_CHECK_JAVA_EXCEPTION(env); - while (env->CallBooleanMethod(iterator, java_iterator_has_next)) { - RAY_CHECK_JAVA_EXCEPTION(env); - jobject map_entry = env->CallObjectMethod(iterator, java_iterator_next); - RAY_CHECK_JAVA_EXCEPTION(env); - auto java_key = (jstring)env->CallObjectMethod(map_entry, java_map_entry_get_key); - RAY_CHECK_JAVA_EXCEPTION(env); - std::string key = JavaStringToNativeString(env, java_key); - auto java_value = env->CallObjectMethod(map_entry, java_map_entry_get_value); - RAY_CHECK_JAVA_EXCEPTION(env); - double value = env->CallDoubleMethod(java_value, java_double_double_value); - RAY_CHECK_JAVA_EXCEPTION(env); - resources.emplace(key, value); - } - RAY_CHECK_JAVA_EXCEPTION(env); - } - return resources; + return JavaMapToNativeMap( + env, java_resources, + [](JNIEnv *env, jobject java_key) { + return JavaStringToNativeString(env, (jstring)java_key); + }, + [](JNIEnv *env, jobject java_value) { + double value = env->CallDoubleMethod(java_value, java_double_double_value); + RAY_CHECK_JAVA_EXCEPTION(env); + return value; + }); } inline ray::TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptions) { diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index dd49d57c9..7066f79cc 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -113,6 +113,7 @@ int main(int argc, char *argv[]) { static_resource_conf[resource_name] = std::stod(resource_quantity); } + node_manager_config.raylet_config = raylet_config; node_manager_config.resource_config = ray::ResourceSet(std::move(static_resource_conf)); RAY_LOG(DEBUG) << "Starting raylet with static resource configuration: " << node_manager_config.resource_config.ToString(); diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index d5c92f63b..3bd06d0de 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -106,7 +106,7 @@ NodeManager::NodeManager(boost::asio::io_service &io_service, local_available_resources_(config.resource_config), worker_pool_( io_service, config.num_initial_workers, config.maximum_startup_concurrency, - gcs_client_, config.worker_commands, + gcs_client_, config.worker_commands, config.raylet_config, /*starting_worker_timeout_callback=*/ [this]() { this->DispatchTasks(this->local_queues_.GetReadyTasksByClass()); }), scheduling_policy_(local_queues_), diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 1061933c1..a01a81566 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -84,6 +84,8 @@ struct NodeManagerConfig { std::string temp_dir; /// The path of this ray session dir. std::string session_dir; + /// The raylet config list of this node. + std::unordered_map raylet_config; }; class NodeManager : public rpc::NodeManagerServiceHandler { diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 6d8e7fcc2..658e04c34 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -60,10 +60,12 @@ WorkerPool::WorkerPool(boost::asio::io_service &io_service, int num_workers, int maximum_startup_concurrency, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, + const std::unordered_map &raylet_config, std::function starting_worker_timeout_callback) : io_service_(&io_service), maximum_startup_concurrency_(maximum_startup_concurrency), gcs_client_(std::move(gcs_client)), + raylet_config_(raylet_config), starting_worker_timeout_callback_(starting_worker_timeout_callback) { RAY_CHECK(maximum_startup_concurrency > 0); #ifndef _WIN32 @@ -173,7 +175,7 @@ Process WorkerPool::StartWorkerProcess(const Language &language, // Extract pointers from the worker command to pass into execvp. std::vector worker_command_args; size_t dynamic_option_index = 0; - bool num_workers_arg_replaced = false; + bool worker_raylet_config_placeholder_found = false; for (auto const &token : state.worker_command) { const auto option_placeholder = kWorkerDynamicOptionPlaceholderPrefix + std::to_string(dynamic_option_index); @@ -186,23 +188,48 @@ Process WorkerPool::StartWorkerProcess(const Language &language, options.end()); ++dynamic_option_index; } - } else { - size_t num_workers_index = token.find(kWorkerNumWorkersPlaceholder); - if (num_workers_index != std::string::npos) { - std::string arg = token; - worker_command_args.push_back(arg.replace(num_workers_index, - strlen(kWorkerNumWorkersPlaceholder), - std::to_string(workers_to_start))); - num_workers_arg_replaced = true; - } else { - worker_command_args.push_back(token); - } + continue; } + + if (token == kWorkerRayletConfigPlaceholder) { + worker_raylet_config_placeholder_found = true; + switch (language) { + case Language::JAVA: + for (auto &entry : raylet_config_) { + if (entry.first == "num_workers_per_process_java") { + continue; + } + std::string arg; + arg.append("-Dray.raylet.config."); + arg.append(entry.first); + arg.append("="); + arg.append(entry.second); + worker_command_args.push_back(arg); + } + // The value of `num_workers_per_process_java` may change depends on whether + // dynamic options is empty, so we can't use the value in `RayConfig`. We always + // overwrite the value here. + worker_command_args.push_back( + "-Dray.raylet.config.num_workers_per_process_java=" + + std::to_string(workers_to_start)); + break; + default: + RAY_LOG(FATAL) + << "Raylet config placeholder is not supported for worker language " + << language; + } + continue; + } + + worker_command_args.push_back(token); + } + + // Currently only Java worker process supports multi-worker. + if (language == Language::JAVA) { + RAY_CHECK(worker_raylet_config_placeholder_found) + << "The " << kWorkerRayletConfigPlaceholder + << " placeholder is not found in worker command."; } - RAY_CHECK(num_workers_arg_replaced || state.num_workers_per_process == 1) - << "Expect to start " << state.num_workers_per_process << " workers per " - << Language_Name(language) << " worker process. But the " - << kWorkerNumWorkersPlaceholder << "placeholder is not found in worker command."; Process proc = StartProcess(worker_command_args); RAY_LOG(DEBUG) << "Started worker process of " << workers_to_start diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index 3a1a55235..ec2a98b20 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -56,11 +56,13 @@ class WorkerPool { /// resources on the machine). /// \param worker_commands The commands used to start the worker process, grouped by /// language. + /// \param raylet_config The raylet config list of this node. /// \param starting_worker_timeout_callback The callback that will be triggered once /// it times out to start a worker. WorkerPool(boost::asio::io_service &io_service, int num_workers, int maximum_startup_concurrency, std::shared_ptr gcs_client, const WorkerCommandMap &worker_commands, + const std::unordered_map &raylet_config, std::function starting_worker_timeout_callback); /// Destructor responsible for freeing a set of workers owned by this class. @@ -252,6 +254,8 @@ class WorkerPool { int maximum_startup_concurrency_; /// A client connection to the GCS. std::shared_ptr gcs_client_; + /// The raylet config list of this node. + std::unordered_map raylet_config_; /// The callback that will be triggered once it times out to start a worker. std::function starting_worker_timeout_callback_; FRIEND_TEST(WorkerPoolTest, InitialWorkerProcessCount); diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index c77fcd3f4..7c5a44392 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -24,7 +24,7 @@ namespace ray { namespace raylet { -int NUM_WORKERS_PER_PROCESS = 3; +int NUM_WORKERS_PER_PROCESS_JAVA = 3; int MAXIMUM_STARTUP_CONCURRENCY = 5; std::vector LANGUAGES = {Language::PYTHON, Language::JAVA}; @@ -34,20 +34,17 @@ class WorkerPoolMock : public WorkerPool { WorkerPoolMock(boost::asio::io_service &io_service) : WorkerPoolMock( io_service, - {{Language::PYTHON, - {"dummy_py_worker_command", "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"}}, + {{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, - {"dummy_java_worker_command", - "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER"}}}) {} + {"dummy_java_worker_command", "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER"}}}) {} explicit WorkerPoolMock(boost::asio::io_service &io_service, const WorkerCommandMap &worker_commands) : WorkerPool(io_service, 0, MAXIMUM_STARTUP_CONCURRENCY, nullptr, worker_commands, - []() {}), + {}, []() {}), last_worker_process_() { - for (auto &entry : states_by_lang_) { - entry.second.num_workers_per_process = NUM_WORKERS_PER_PROCESS; - } + states_by_lang_[ray::Language::JAVA].num_workers_per_process = + NUM_WORKERS_PER_PROCESS_JAVA; } ~WorkerPoolMock() { @@ -130,6 +127,34 @@ class WorkerPoolTest : public ::testing::Test { this->worker_pool_ = std::move(worker_pool); } + void TestStartupWorkerProcessCount(Language language, int num_workers_per_process, + std::vector expected_worker_command) { + int desired_initial_worker_process_count = 100; + int expected_worker_process_count = static_cast(std::ceil( + static_cast(MAXIMUM_STARTUP_CONCURRENCY) / num_workers_per_process)); + ASSERT_TRUE(expected_worker_process_count < + static_cast(desired_initial_worker_process_count)); + Process last_started_worker_process; + for (int i = 0; i < desired_initial_worker_process_count; i++) { + worker_pool_.StartWorkerProcess(language); + ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() <= + expected_worker_process_count); + Process prev = worker_pool_.LastStartedWorkerProcess(); + if (!std::equal_to()(last_started_worker_process, prev)) { + last_started_worker_process = prev; + const auto &real_command = + worker_pool_.GetWorkerCommand(last_started_worker_process); + ASSERT_EQ(real_command, expected_worker_command); + } else { + ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), + expected_worker_process_count); + ASSERT_TRUE(i >= expected_worker_process_count); + } + } + // Check number of starting workers + ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), expected_worker_process_count); + } + protected: boost::asio::io_service io_service_; WorkerPoolMock worker_pool_; @@ -177,10 +202,10 @@ TEST_F(WorkerPoolTest, CompareWorkerProcessObjects) { } TEST_F(WorkerPoolTest, HandleWorkerRegistration) { - Process proc = worker_pool_.StartWorkerProcess(Language::PYTHON); + Process proc = worker_pool_.StartWorkerProcess(Language::JAVA); std::vector> workers; - for (int i = 0; i < NUM_WORKERS_PER_PROCESS; i++) { - workers.push_back(CreateWorker(Process())); + for (int i = 0; i < NUM_WORKERS_PER_PROCESS_JAVA; i++) { + workers.push_back(CreateWorker(Process(), Language::JAVA)); } for (const auto &worker : workers) { // Check that there's still a starting worker process @@ -201,51 +226,26 @@ TEST_F(WorkerPoolTest, HandleWorkerRegistration) { } } -TEST_F(WorkerPoolTest, StartupWorkerProcessCount) { - std::string num_workers_arg = - std::string("--foo=") + std::to_string(NUM_WORKERS_PER_PROCESS); - std::vector> worker_commands = { - {{"dummy_py_worker_command", num_workers_arg}, - {"dummy_java_worker_command", num_workers_arg}}}; - int desired_initial_worker_process_count_per_language = 100; - int expected_worker_process_count = - static_cast(std::ceil(static_cast(MAXIMUM_STARTUP_CONCURRENCY) / - NUM_WORKERS_PER_PROCESS * LANGUAGES.size())); - ASSERT_TRUE(expected_worker_process_count < - static_cast(desired_initial_worker_process_count_per_language * - LANGUAGES.size())); - Process last_started_worker_process; - for (int i = 0; i < desired_initial_worker_process_count_per_language; i++) { - for (size_t j = 0; j < LANGUAGES.size(); j++) { - worker_pool_.StartWorkerProcess(LANGUAGES[j]); - ASSERT_TRUE(worker_pool_.NumWorkerProcessesStarting() <= - expected_worker_process_count); - Process prev = worker_pool_.LastStartedWorkerProcess(); - if (!std::equal_to()(last_started_worker_process, prev)) { - last_started_worker_process = prev; - const auto &real_command = - worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); - ASSERT_EQ(real_command, worker_commands[j]); - } else { - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), - expected_worker_process_count); - ASSERT_TRUE(static_cast(i * LANGUAGES.size() + j) >= - expected_worker_process_count); - } - } - } - // Check number of starting workers - ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), expected_worker_process_count); +TEST_F(WorkerPoolTest, StartupPythonWorkerProcessCount) { + TestStartupWorkerProcessCount(Language::PYTHON, 1, {"dummy_py_worker_command"}); +} + +TEST_F(WorkerPoolTest, StartupJavaWorkerProcessCount) { + TestStartupWorkerProcessCount( + Language::JAVA, NUM_WORKERS_PER_PROCESS_JAVA, + {"dummy_java_worker_command", + std::string("-Dray.raylet.config.num_workers_per_process_java=") + + std::to_string(NUM_WORKERS_PER_PROCESS_JAVA)}); } TEST_F(WorkerPoolTest, InitialWorkerProcessCount) { worker_pool_.Start(1); - // Here we try to start only 1 worker for each worker language. But since each worker - // process contains exactly NUM_WORKERS_PER_PROCESS (3) workers here, it's expected to - // see 3 workers for each worker language, instead of 1. + // Here we try to start only 1 worker for each worker language. But since each Java + // worker process contains exactly NUM_WORKERS_PER_PROCESS_JAVA (3) workers here, + // it's expected to see 3 workers for Java and 1 worker for Python, instead of 1 for + // each worker language. ASSERT_NE(worker_pool_.NumWorkersStarting(), 1 * LANGUAGES.size()); - ASSERT_EQ(worker_pool_.NumWorkersStarting(), - NUM_WORKERS_PER_PROCESS * LANGUAGES.size()); + ASSERT_EQ(worker_pool_.NumWorkersStarting(), 1 + NUM_WORKERS_PER_PROCESS_JAVA); ASSERT_EQ(worker_pool_.NumWorkerProcessesStarting(), LANGUAGES.size()); } @@ -320,8 +320,7 @@ TEST_F(WorkerPoolTest, PopWorkersOfMultipleLanguages) { TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { const std::vector java_worker_command = { "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_0", "dummy_java_worker_command", - "--foo=RAY_WORKER_NUM_WORKERS_PLACEHOLDER", - "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1"}; + "RAY_WORKER_RAYLET_CONFIG_PLACEHOLDER", "RAY_WORKER_DYNAMIC_OPTION_PLACEHOLDER_1"}; SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); @@ -334,7 +333,8 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { worker_pool_.GetWorkerCommand(worker_pool_.LastStartedWorkerProcess()); ASSERT_EQ(real_command, std::vector( - {"test_op_0", "dummy_java_worker_command", "--foo=1", "test_op_1"})); + {"test_op_0", "dummy_java_worker_command", + "-Dray.raylet.config.num_workers_per_process_java=1", "test_op_1"})); } } // namespace raylet