diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 6d8bbb97c..016b2bf6a 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1811,3 +1811,38 @@ def start_monitor(redis_address, stderr_file=stderr_file, fate_share=fate_share) return process_info + + +def start_ray_client_server(redis_address, + ray_client_server_port, + stdout_file=None, + stderr_file=None, + redis_password=None, + fate_share=None): + """Run the server process of the Ray client. + + Args: + ray_client_server_port (int): Port the Ray client server listens on. + stdout_file: A file handle opened for writing to redirect stdout to. If + no redirection should happen, then this should be None. + stderr_file: A file handle opened for writing to redirect stderr to. If + no redirection should happen, then this should be None. + redis_password (str): The password of the redis server. + + Returns: + ProcessInfo for the process that was started. + """ + command = [ + sys.executable, "-m", "ray.util.client.server", + "--redis-address=" + str(redis_address), + "--port=" + str(ray_client_server_port) + ] + if redis_password: + command.append("--redis-password=" + redis_password) + process_info = start_ray_process( + command, + ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER, + stdout_file=stdout_file, + stderr_file=stderr_file, + fate_share=fate_share) + return process_info diff --git a/python/ray/node.py b/python/ray/node.py index a7ec72e7a..3eb1fcd0d 100644 --- a/python/ray/node.py +++ b/python/ray/node.py @@ -752,6 +752,23 @@ class Node: assert ray_constants.PROCESS_TYPE_MONITOR not in self.all_processes self.all_processes[ray_constants.PROCESS_TYPE_MONITOR] = [process_info] + def start_ray_client_server(self): + """Start the ray client server process.""" + stdout_file, stderr_file = self.get_log_file_handles( + "ray_client_server", unique=True) + process_info = ray._private.services.start_ray_client_server( + self._redis_address, + self._ray_params.ray_client_server_port, + stdout_file=stdout_file, + stderr_file=stderr_file, + redis_password=self._ray_params.redis_password, + fate_share=self.kernel_fate_share) + assert (ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER not in + self.all_processes) + self.all_processes[ray_constants.PROCESS_TYPE_RAY_CLIENT_SERVER] = [ + process_info + ] + def start_head_processes(self): """Start head processes on the node.""" logger.debug(f"Process STDOUT and STDERR is being " @@ -764,6 +781,9 @@ class Node: self.start_monitor() + if self._ray_params.ray_client_server_port: + self.start_ray_client_server() + if self._ray_params.include_dashboard: self.start_dashboard(require_dashboard=True) elif self._ray_params.include_dashboard is None: diff --git a/python/ray/parameter.py b/python/ray/parameter.py index f6bbc243c..c31e09df1 100644 --- a/python/ray/parameter.py +++ b/python/ray/parameter.py @@ -45,6 +45,9 @@ class RayParams: worker_port_list (str): An explicit list of ports to be used for workers (comma-separated). Overrides min_worker_port and max_worker_port. + ray_client_server_port (int): The port number the ray client server + will bind on. If not set, the ray client server will not + be started. object_ref_seed (int): Used to seed the deterministic generation of object refs. The same value can be used across multiple runs of the same job in order to generate the object refs in a consistent @@ -120,6 +123,7 @@ class RayParams: min_worker_port=None, max_worker_port=None, worker_port_list=None, + ray_client_server_port=None, object_ref_seed=None, driver_mode=None, redirect_worker_output=None, @@ -165,6 +169,7 @@ class RayParams: self.min_worker_port = min_worker_port self.max_worker_port = max_worker_port self.worker_port_list = worker_port_list + self.ray_client_server_port = ray_client_server_port self.driver_mode = driver_mode self.redirect_worker_output = redirect_worker_output self.redirect_output = redirect_output diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 30b3b5c7b..a5459b863 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -160,6 +160,7 @@ LOGGING_ROTATE_BACKUP_COUNT = 50 # backupCount # Constants used to define the different process types. PROCESS_TYPE_REAPER = "reaper" PROCESS_TYPE_MONITOR = "monitor" +PROCESS_TYPE_RAY_CLIENT_SERVER = "ray_client_server" PROCESS_TYPE_LOG_MONITOR = "log_monitor" # TODO(sang): Delete it. PROCESS_TYPE_REPORTER = "reporter" diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 4a1dd6e28..11941f87a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -278,6 +278,13 @@ def debug(address): required=False, help="a comma-separated list of open ports for workers to bind on. " "Overrides '--min-worker-port' and '--max-worker-port'.") +@click.option( + "--ray-client-server-port", + required=False, + type=int, + default=None, + help="the port number the ray client server will bind on. If not set, " + "the ray client server will not be started.") @click.option( "--memory", required=False, @@ -415,9 +422,10 @@ def debug(address): @add_click_options(logging_options) def start(node_ip_address, address, port, redis_password, redis_shard_ports, object_manager_port, node_manager_port, gcs_server_port, - min_worker_port, max_worker_port, worker_port_list, memory, - object_store_memory, redis_max_memory, num_cpus, num_gpus, resources, - head, include_dashboard, dashboard_host, dashboard_port, block, + min_worker_port, max_worker_port, worker_port_list, + ray_client_server_port, memory, object_store_memory, + redis_max_memory, num_cpus, num_gpus, resources, head, + include_dashboard, dashboard_host, dashboard_port, block, plasma_directory, autoscaling_config, no_redirect_worker_output, no_redirect_output, plasma_store_socket_name, raylet_socket_name, temp_dir, java_worker_options, system_config, lru_evict, @@ -459,6 +467,7 @@ def start(node_ip_address, address, port, redis_password, redis_shard_ports, min_worker_port=min_worker_port, max_worker_port=max_worker_port, worker_port_list=worker_port_list, + ray_client_server_port=ray_client_server_port, object_manager_port=object_manager_port, node_manager_port=node_manager_port, gcs_server_port=gcs_server_port, @@ -698,6 +707,7 @@ def stop(force, verbose, log_style, log_color): ["plasma_store", True], ["gcs_server", True], ["monitor.py", False], + ["ray.util.client.server", False], ["redis-server", False], ["default_worker.py", False], # Python worker. ["ray::", True], # Python worker. TODO(mehrdadn): Fix for Windows diff --git a/python/ray/tests/test_multi_node_2.py b/python/ray/tests/test_multi_node_2.py index 6578bdeb9..705b81589 100644 --- a/python/ray/tests/test_multi_node_2.py +++ b/python/ray/tests/test_multi_node_2.py @@ -186,6 +186,23 @@ def test_wait_for_nodes(ray_start_cluster_head): assert ray.cluster_resources()["CPU"] == 1 +@pytest.mark.parametrize( + "call_ray_start", [ + "ray start --head --ray-client-server-port 20000 " + + "--min-worker-port=0 --max-worker-port=0 --port 0" + ], + indirect=True) +def test_ray_client(call_ray_start): + from ray.util.client import ray + ray.connect("localhost:20000") + + @ray.remote + def f(): + return "hello client" + + assert ray.get(f.remote()) == "hello client" + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index add66c82c..38f095f3f 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -52,7 +52,8 @@ class DataClient: stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel) resp_stream = stub.Datapath( iter(self.request_queue.get, None), - metadata=(("client_id", self._client_id), )) + metadata=(("client_id", self._client_id), ), + wait_for_ready=True) try: for response in resp_stream: if response.req_id == 0: diff --git a/python/ray/util/client/server/__init__.py b/python/ray/util/client/server/__init__.py index e69de29bb..37c7767bb 100644 --- a/python/ray/util/client/server/__init__.py +++ b/python/ray/util/client/server/__init__.py @@ -0,0 +1 @@ +from ray.util.client.server.server import serve # noqa diff --git a/python/ray/util/client/server/__main__.py b/python/ray/util/client/server/__main__.py new file mode 100644 index 000000000..c31fd5eb4 --- /dev/null +++ b/python/ray/util/client/server/__main__.py @@ -0,0 +1,3 @@ +if __name__ == "__main__": + from ray.util.client.server.server import main + main() diff --git a/python/ray/util/client/server/server.py b/python/ray/util/client/server/server.py index c065b9c66..82c76a7c9 100644 --- a/python/ray/util/client/server/server.py +++ b/python/ray/util/client/server/server.py @@ -418,13 +418,39 @@ def shutdown_with_server(server, _exiting_interpreter=False): ray.shutdown(_exiting_interpreter) -if __name__ == "__main__": +def main(): + import argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--host", type=str, default="0.0.0.0", help="Host IP to bind to") + parser.add_argument( + "-p", "--port", type=int, default=50051, help="Port to bind to") + parser.add_argument( + "--redis-address", + required=True, + type=str, + help="Address to use to connect to Ray") + parser.add_argument( + "--redis-password", + required=False, + type=str, + help="Password for connecting to Redis") + args = parser.parse_args() logging.basicConfig(level="INFO") - # TODO(barakmich): Perhaps wrap ray init - ray.init() - server = serve("0.0.0.0:50051") + if args.redis_password: + ray.init( + address=args.redis_address, _redis_password=args.redis_password) + else: + ray.init(address=args.redis_address) + hostport = "%s:%d" % (args.host, args.port) + logger.info(f"Starting Ray Client server on {hostport}") + server = serve(hostport) try: while True: time.sleep(1000) except KeyboardInterrupt: server.stop(0) + + +if __name__ == "__main__": + main()