diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index b5578404d..8e4fd8059 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -77,4 +77,38 @@ test + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies-to-build + package + + copy-dependencies + + + ${basedir}/../../build/java + false + false + true + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.3.1 + + ${basedir}/../../build/java + + + + + diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 8374ffdcc..62aba7a45 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -71,6 +71,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } redisClient = new RedisClient(rayConfig.getRedisAddress()); + // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); rayletClient = new RayletClientImpl( diff --git a/java/test/src/main/java/org/ray/api/test/ActorTest.java b/java/test/src/main/java/org/ray/api/test/ActorTest.java index 2c47d84c3..e0fc595e1 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorTest.java @@ -4,7 +4,6 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.annotation.RayRemote; -import org.ray.api.function.RayFunc2; import org.ray.api.id.UniqueId; import org.ray.runtime.RayActorImpl; import org.testng.Assert; diff --git a/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java b/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java new file mode 100644 index 000000000..d487acc03 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/MultiLanguageClusterTest.java @@ -0,0 +1,115 @@ +package org.ray.api.test; + +import com.google.common.collect.ImmutableList; +import java.io.File; +import java.io.IOException; +import java.lang.ProcessBuilder.Redirect; +import java.util.List; +import java.util.concurrent.TimeUnit; +import org.ray.api.Ray; +import org.ray.api.RayObject; +import org.ray.api.annotation.RayRemote; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +/** + * Test starting a ray cluster with multi-language support. + */ +public class MultiLanguageClusterTest { + + private static final Logger LOGGER = LoggerFactory.getLogger(MultiLanguageClusterTest.class); + + private static final String PLASMA_STORE_SOCKET_NAME = "/tmp/ray/test/plasma_store_socket"; + private static final String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; + + @RayRemote + public static String echo(String word) { + return word; + } + + /** + * Execute an external command. + * @return Whether the command succeeded. + */ + private boolean executeCommand(List command, int waitTimeoutSeconds) { + try { + LOGGER.info("Executing command: {}", String.join(" ", command)); + Process process = new ProcessBuilder(command).redirectOutput(Redirect.INHERIT) + .redirectError(Redirect.INHERIT).start(); + process.waitFor(waitTimeoutSeconds, TimeUnit.SECONDS); + return process.exitValue() == 0; + } catch (Exception e) { + throw new RuntimeException("Error executing command " + String.join(" ", command), e); + } + } + + @BeforeMethod + public void setUp() { + // Check whether 'ray' command is installed. + boolean rayCommandExists = executeCommand(ImmutableList.of("which", "ray"), 5); + if (!rayCommandExists) { + throw new SkipException("Skipping test, because ray command doesn't exist."); + } + + // Delete existing socket files. + for (String socket : ImmutableList.of(RAYLET_SOCKET_NAME, PLASMA_STORE_SOCKET_NAME)) { + File file = new File(socket); + if (file.exists()) { + file.delete(); + } + } + + // Start ray cluster. + final List startCommand = ImmutableList.of( + "ray", + "start", + "--head", + "--redis-port=6379", + "--include-java", + String.format("--plasma-store-socket-name=%s", PLASMA_STORE_SOCKET_NAME), + String.format("--raylet-socket-name=%s", RAYLET_SOCKET_NAME), + "--java-worker-options=-classpath ../../build/java/*:../../java/test/target/*" + ); + if (!executeCommand(startCommand, 10)) { + throw new RuntimeException("Couldn't start ray cluster."); + } + + // Connect to the cluster. + System.setProperty("ray.home", "../.."); + System.setProperty("ray.redis.address", "127.0.0.1:6379"); + System.setProperty("ray.object-store.socket-name", PLASMA_STORE_SOCKET_NAME); + System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); + Ray.init(); + } + + @AfterMethod + public void tearDown() { + // Disconnect to the cluster. + Ray.shutdown(); + System.clearProperty("ray.home"); + System.clearProperty("ray.redis.address"); + System.clearProperty("ray.object-store.socket-name"); + System.clearProperty("ray.raylet.socket-name"); + + // Stop ray cluster. + final List stopCommand = ImmutableList.of( + "ray", + "stop" + ); + if (!executeCommand(stopCommand, 10)) { + throw new RuntimeException("Couldn't stop ray cluster"); + } + } + + @Test + public void testMultiLanguageCluster() { + RayObject obj = Ray.call(MultiLanguageClusterTest::echo, "hello"); + Assert.assertEquals("hello", obj.get()); + } + +} diff --git a/python/ray/node.py b/python/ray/node.py index 7f6e99bad..fa722368b 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -62,6 +62,11 @@ class Node(object): if head: ray_params.update_if_absent(num_redis_shards=1, include_webui=True) + else: + redis_client = ray.services.create_redis_client( + ray_params.redis_address, ray_params.redis_password) + ray_params.include_java = ( + ray.services.include_java_from_redis(redis_client)) self._ray_params = ray_params self._config = (json.loads(ray_params._internal_config) @@ -224,7 +229,10 @@ class Node(object): use_profiler=use_profiler, stdout_file=stdout_file, stderr_file=stderr_file, - config=self._config) + config=self._config, + include_java=self._ray_params.include_java, + java_worker_options=self._ray_params.java_worker_options, + ) assert ray_constants.PROCESS_TYPE_RAYLET not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_RAYLET] = [process_info] diff --git a/python/ray/parameter.py b/python/ray/parameter.py index f24bdfd6f..5b4a0cb48 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -70,6 +70,9 @@ class RayParams(object): monitor the log files for all processes on this node and push their contents to Redis. autoscaling_config: path to autoscaling config file. + include_java (bool): If True, the raylet backend can also support + Java worker. + java_worker_options (str): The command options for Java worker. _internal_config (str): JSON configuration for overriding RayConfig defaults. For testing purposes ONLY. """ @@ -106,6 +109,8 @@ class RayParams(object): temp_dir=None, include_log_monitor=None, autoscaling_config=None, + include_java=False, + java_worker_options=None, _internal_config=None): self.object_id_seed = object_id_seed self.redis_address = redis_address @@ -136,6 +141,8 @@ class RayParams(object): self.temp_dir = temp_dir self.include_log_monitor = include_log_monitor self.autoscaling_config = autoscaling_config + self.include_java = include_java + self.java_worker_options = java_worker_options self._internal_config = _internal_config self._check_usage() @@ -146,7 +153,7 @@ class RayParams(object): kwargs: The keyword arguments to set corresponding fields. """ for arg in kwargs: - if (hasattr(self, arg)): + if hasattr(self, arg): setattr(self, arg, kwargs[arg]) else: raise ValueError("Invalid RayParams parameter in" @@ -161,7 +168,7 @@ class RayParams(object): kwargs: The keyword arguments to set corresponding fields. """ for arg in kwargs: - if (hasattr(self, arg)): + if hasattr(self, arg): if getattr(self, arg) is None: setattr(self, arg, kwargs[arg]) else: @@ -180,6 +187,10 @@ class RayParams(object): "num_gpus instead.") if self.num_workers is not None: - raise Exception( + raise ValueError( "The 'num_workers' argument is deprecated. Please use " "'num_cpus' instead.") + + if self.include_java is None and self.java_worker_options is not None: + raise ValueError("Should not specify `java-worker-options` " + "without providing `include-java`.") diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index dbe9e0104..3ec867dbe 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -201,6 +201,17 @@ def cli(logging_level, logging_format): "--temp-dir", default=None, help="manually specify the root temporary dir of the Ray process") +@click.option( + "--include-java", + is_flag=True, + default=None, + help="Enable Java worker support.") +@click.option( + "--java-worker-options", + required=False, + default=None, + type=str, + help="Overwrite the options to start Java workers.") @click.option( "--internal-config", default=None, @@ -212,8 +223,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, redis_max_memory, num_workers, num_cpus, num_gpus, resources, head, no_ui, block, plasma_directory, huge_pages, autoscaling_config, no_redirect_worker_output, no_redirect_output, - plasma_store_socket_name, raylet_socket_name, temp_dir, - internal_config): + plasma_store_socket_name, raylet_socket_name, temp_dir, include_java, + java_worker_options, internal_config): # Convert hostnames to numerical IP address. if node_ip_address is not None: node_ip_address = services.address_to_ip(node_ip_address) @@ -245,6 +256,8 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, plasma_store_socket_name=plasma_store_socket_name, raylet_socket_name=raylet_socket_name, temp_dir=temp_dir, + include_java=include_java, + java_worker_options=java_worker_options, _internal_config=internal_config) if head: @@ -280,7 +293,9 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, include_webui=(not no_ui), - autoscaling_config=autoscaling_config) + autoscaling_config=autoscaling_config, + include_java=False, + ) node = ray.node.Node(ray_params, head=True, shutdown_at_exit=False) redis_address = node.redis_address @@ -322,6 +337,10 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, if no_ui: raise Exception("If --head is not passed in, the --no-ui flag is " "not relevant.") + if include_java is not None: + raise ValueError("--include-java should only be set for the head " + "node.") + redis_ip_address, redis_port = redis_address.split(":") # Wait for the Redis server to be started. And throw an exception if we @@ -348,7 +367,6 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, check_no_existing_redis_clients(ray_params.node_ip_address, redis_client) ray_params.update(redis_address=redis_address) - node = ray.node.Node(ray_params, head=False, shutdown_at_exit=False) logger.info("\nStarted Ray on this node. If you wish to terminate the " "processes that have been started, run\n\n" diff --git a/python/ray/services.py b/python/ray/services.py index adec81aa0..c9eb23058 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -21,14 +21,19 @@ import pyarrow import ray import ray.ray_constants as ray_constants -from ray.tempfile_services import (get_ipython_notebook_path, get_temp_root, - new_redis_log_file) +from ray.tempfile_services import ( + get_ipython_notebook_path, + get_logs_dir_path, + get_temp_root, + new_redis_log_file, +) # True if processes are run in the valgrind profiler. RUN_RAYLET_PROFILER = False RUN_PLASMA_STORE_PROFILER = False # Location of the redis server and module. +RAY_HOME = os.path.join(os.path.dirname(__file__), "../..") REDIS_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), "core/src/ray/thirdparty/redis/src/redis-server") @@ -60,6 +65,10 @@ RAYLET_MONITOR_EXECUTABLE = os.path.join( RAYLET_EXECUTABLE = os.path.join( os.path.abspath(os.path.dirname(__file__)), "core/src/ray/raylet/raylet") +DEFAULT_JAVA_WORKER_OPTIONS = "-classpath {}".format( + os.path.join( + os.path.abspath(os.path.dirname(__file__)), "../../../build/java/*")) + # Logger for this module. It should be configured at the entry point # into the program using Ray. Ray provides a default configuration at # entry/init points. @@ -93,6 +102,18 @@ def new_port(): return random.randint(10000, 65535) +def include_java_from_redis(redis_client): + """This is used for query include_java bool from redis. + + Args: + redis_client (StrictRedis): The redis client to GCS. + + Returns: + True if this cluster backend enables Java worker. + """ + return redis_client.get("INCLUDE_JAVA") == b"1" + + def remaining_processes_alive(): """See if the remaining processes are alive or not. @@ -249,8 +270,8 @@ def start_ray_process(command, no redirection should happen, then this should be None. Returns: - Inormation about the process that was started including a handle to the - process that was started. + Information about the process that was started including a handle to + the process that was started. """ # Detect which flags are set through environment variables. valgrind_env_var = "RAY_{}_VALGRIND".format(process_type.upper()) @@ -451,7 +472,8 @@ def start_redis(node_ip_address, redirect_worker_output=False, password=None, use_credis=None, - redis_max_memory=None): + redis_max_memory=None, + include_java=False): """Start the Redis global state store. Args: @@ -481,6 +503,8 @@ def start_redis(node_ip_address, LRU eviction of entries. This only applies to the sharded redis tables (task, object, and profile tables). By default, this is capped at 10GB but can be set higher. + include_java (bool): If True, the raylet backend can also support + Java worker. Returns: A tuple of the address for the primary Redis shard, a list of @@ -555,6 +579,10 @@ def start_redis(node_ip_address, primary_redis_client.set("RedirectOutput", 1 if redirect_worker_output else 0) + # put the include_java bool to primary redis-server, so that other nodes + # can access it and know whether or not to enable cross-languages. + primary_redis_client.set("INCLUDE_JAVA", 1 if include_java else 0) + # Store version information in the primary Redis shard. _put_version_info_in_redis(primary_redis_client) @@ -960,7 +988,9 @@ def start_raylet(redis_address, use_profiler=False, stdout_file=None, stderr_file=None, - config=None): + config=None, + include_java=False, + java_worker_options=None): """Start a raylet, which is a combined local scheduler and object manager. Args: @@ -989,7 +1019,9 @@ def start_raylet(redis_address, no redirection should happen, then this should be None. config (dict|None): Optional Raylet configuration that will override defaults in RayConfig. - + include_java (bool): If True, the raylet backend can also support + Java worker. + java_worker_options (str): The command options for Java worker. Returns: ProcessInfo for the process that was started. """ @@ -1016,6 +1048,14 @@ def start_raylet(redis_address, gcs_ip_address, gcs_port = redis_address.split(":") + if include_java is True: + java_worker_options = (java_worker_options + or DEFAULT_JAVA_WORKER_OPTIONS) + java_worker_command = build_java_worker_command( + java_worker_options, redis_address, plasma_store_name, raylet_name) + else: + java_worker_command = "" + # Create the command that the Raylet will use to start workers. start_worker_command = ("{} {} " "--node-ip-address={} " @@ -1052,7 +1092,7 @@ def start_raylet(redis_address, resource_argument, config_str, start_worker_command, - "", # Worker command for Java, not needed for Python. + java_worker_command, redis_password or "", get_temp_root(), ] @@ -1073,6 +1113,40 @@ def start_raylet(redis_address, return process_info +def build_java_worker_command(java_worker_options, redis_address, + plasma_store_name, raylet_name): + """This method assembles the command used to start a Java worker. + + Args: + java_worker_options (str): The command options for Java worker. + redis_address (str): Redis address of GCS. + plasma_store_name (str): The name of the plasma store socket to connect + to. + raylet_name (str): The name of the raylet socket to create. + + Returns: + The command string for starting Java worker. + """ + assert java_worker_options is not None + + command = "java {} ".format(java_worker_options) + if redis_address is not None: + command += "-Dray.redis.address={} ".format(redis_address) + + if plasma_store_name is not None: + command += ( + "-Dray.object-store.socket-name={} ".format(plasma_store_name)) + + if raylet_name is not None: + command += "-Dray.raylet.socket-name={} ".format(raylet_name) + + command += "-Dray.home={} ".format(RAY_HOME) + command += "-Dray.log-dir={} ".format(get_logs_dir_path()) + command += "org.ray.runtime.runner.worker.DefaultWorker" + + return command + + def determine_plasma_store_config(object_store_memory=None, plasma_directory=None, huge_pages=False):