From e06311d4151d8e5050f922cc0c8e8b5fb0e07435 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 16 Aug 2016 14:53:55 -0700 Subject: [PATCH] Automatically add relevant directories to Python paths of workers (#380) * Make ray.init set python paths of workers. * Decouple starting cluster from copying user source code * also add current directory to path * Add comments about deallocation. * Add test for new code path. --- doc/using-ray-on-a-cluster.md | 21 +++++----- lib/python/ray/services.py | 16 ++------ lib/python/ray/worker.py | 48 +++++++++++++++++++++++ protos/ray.proto | 12 ++++++ scripts/cluster.py | 30 +++++---------- scripts/default_worker.py | 7 ---- src/raylib.cc | 18 ++++++++- src/scheduler.cc | 72 +++++++++++++++++++++++++++++------ src/scheduler.h | 16 ++++++++ src/worker.cc | 26 +++++++++++++ src/worker.h | 3 ++ test/runtest.py | 17 +++++++++ 12 files changed, 222 insertions(+), 64 deletions(-) diff --git a/doc/using-ray-on-a-cluster.md b/doc/using-ray-on-a-cluster.md index d6502a106..623534061 100644 --- a/doc/using-ray-on-a-cluster.md +++ b/doc/using-ray-on-a-cluster.md @@ -127,14 +127,9 @@ the tests. python test/array_test.py # This tests some array libraries. ``` -6. Start the cluster with `cluster.start_ray()`. If you would like to deploy -source code to it, you can pass in the local path to the directory that contains -your Python code. For example, `cluster.start_ray("~/example_ray_code")`. This -will copy your source code to each node on the cluster, placing it in a -directory on the PYTHONPATH. -The `cluster.start_ray` command will start the Ray scheduler, object stores, and -workers, and before finishing it will print instructions for connecting to the -cluster via ssh. +6. Start the cluster with `cluster.start_ray()`. The `cluster.start_ray` command +will start the Ray scheduler, object stores, and workers, and before finishing +it will print instructions for connecting to the cluster via ssh. 7. 
To connect to the cluster (either with a Python shell or with a script), ssh to the cluster's head node (as described by the output of the @@ -146,7 +141,6 @@ to the cluster's head node (as described by the output of the Then run the following commands. - cd $HOME/ray source $HOME/ray/setup-env.sh # Add Ray to your Python path. Then within a Python interpreter, run the following commands. @@ -177,11 +171,14 @@ need to install a few more Python packages. This can be done, within - `cluster.install_ray()` - This pulls the Ray source code on each node, builds all of the third party libraries, and builds the project itself. - - `cluster.start_ray(user_source_directory=None, num_workers_per_node=10)` - - This starts a scheduler process on the head node, and it starts an object - store and some workers on each node. + - `cluster.start_ray(num_workers_per_node=10)` - This starts a scheduler + process on the head node, and it starts an object store and some workers + on each node. - `cluster.stop_ray()` - This shuts down the cluster (killing all of the processes). + - `cluster.copy_code_to_cluster(user_source_directory)` - This copies the + contents of `user_source_directory` locally to the cluster under + `~/ray_source_files/`. - `cluster.update_ray()` - This pulls the latest Ray source code and builds it. - `cluster.run_command_over_ssh_on_all_nodes_in_parallel(command)` - This diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index 7b21cfcc6..985e99e81 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -82,7 +82,7 @@ def start_objstore(scheduler_address, node_ip_address, cleanup): if cleanup: all_processes.append(p) -def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True, user_source_directory=None): +def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True): """This method starts a worker process. 
Args: @@ -96,18 +96,10 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre cleanup (Optional[bool]): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. This is True by default. - user_source_directory (Optional[str]): The directory containing the - application code. This directory will be added to the path of each worker. - If not provided, the directory of the script currently being run is used. """ - if user_source_directory is None: - # This extracts the directory of the script that is currently being run. - # This will allow users to import modules contained in this directory. - user_source_directory = os.path.dirname(os.path.abspath(os.path.join(os.path.curdir, sys.argv[0]))) command = ["python", worker_path, "--node-ip-address=" + node_ip_address, - "--user-source-directory=" + user_source_directory, "--scheduler-address=" + scheduler_address] if objstore_address is not None: command.append("--objstore-address=" + objstore_address) @@ -115,7 +107,7 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre if cleanup: all_processes.append(p) -def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, user_source_directory=None, cleanup=False): +def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, cleanup=False): """Start an object store and associated workers in the cluster setting. This starts an object store and the associated workers when Ray is being used @@ -129,8 +121,6 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None num_workers (int): The number of workers to be started on this node. worker_path (str): Path of the Python worker script that will be run on the worker. - user_source_directory (str): Path to the user's code the workers will import - modules from. 
cleanup (bool): If cleanup is True, then the processes started by this command will be killed when the process that imported services exits. """ @@ -139,7 +129,7 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None if worker_path is None: worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py") for _ in range(num_workers): - start_worker(node_ip_address, worker_path, scheduler_address, user_source_directory=user_source_directory, cleanup=cleanup) + start_worker(node_ip_address, worker_path, scheduler_address, cleanup=cleanup) time.sleep(0.5) def start_workers(scheduler_address, objstore_address, num_workers, worker_path): diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index deeca503a..696d520d2 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -1,4 +1,5 @@ import os +import sys import time import traceback import copy @@ -516,6 +517,23 @@ class Worker(object): objectids = raylib.submit_task(self.handle, task_capsule) return objectids + def run_function_on_all_workers(self, function): + """Run arbitrary code on all of the workers. + + This function will first be run on the driver, and then it will be exported + to all of the workers to be run. It will also be run on any new workers that + register later. + + Args: + function (Callable): The function to run on all of the workers. It should + not take any arguments. If it returns anything, its return values will + not be used. + """ + # First run the function on the driver. + function() + # Then run the function on all of the workers. + raylib.run_function_on_all_workers(self.handle, pickling.dumps(function)) + global_worker = Worker() """Worker: The global Worker object for this worker process. 
@@ -760,8 +778,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl _logger().setLevel(logging.DEBUG) _logger().propagate = False if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + # Add the directory containing the script that is running to the Python + # paths of the workers. Also add the current directory. Note that this + # assumes that the directory structures on the machines in the cluster are + # the same. + script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) + current_directory = os.path.abspath(os.path.curdir) + worker.run_function_on_all_workers(lambda : sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers(lambda : sys.path.insert(1, current_directory)) + # Export cached remote functions to the workers. for function_name, function_to_export in worker.cached_remote_functions: raylib.export_remote_function(worker.handle, function_name, function_to_export) + # Export cached reusable variables to the workers. for name, reusable_variable in reusables._cached_reusables: _export_reusable_variable(name, reusable_variable) worker.cached_remote_functions = None @@ -998,6 +1026,23 @@ def main_loop(worker=global_worker): else: _logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name)) + def process_function_to_run(serialized_function): + """Run an arbitrary function on the worker.""" + try: + # Deserialize the function. + function = pickling.loads(serialized_function) + # Run the function. + function() + except: + # If an exception was thrown when the function was run, we record the + # traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + _logger().info("Failed to run function on worker. Failed with message: \n\n{}\n".format(traceback_str)) + # Notify the scheduler that running the function failed. + # TODO(rkn): Notify the scheduler. 
+ else: + _logger().info("Successfully ran function on worker.") + while True: command, command_args = raylib.wait_for_next_message(worker.handle) try: @@ -1013,6 +1058,9 @@ def main_loop(worker=global_worker): elif command == "reusable_variable": name, initializer_str, reinitializer_str = command_args process_reusable_variable(name, initializer_str, reinitializer_str) + elif command == "function_to_run": + serialized_function = command_args + process_function_to_run(serialized_function) else: _logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.") assert False, "This code should be unreachable." diff --git a/protos/ray.proto b/protos/ray.proto index cd2eec032..916e5d2d6 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -52,6 +52,8 @@ service Scheduler { rpc TaskInfo(TaskInfoRequest) returns (TaskInfoReply); // Kills the workers rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply); + // Run a function on all workers + rpc RunFunctionOnAllWorkers(RunFunctionOnAllWorkersRequest) returns (AckReply); // Exports function to the workers rpc ExportRemoteFunction(ExportRemoteFunctionRequest) returns (AckReply); // Ship an initializer and reinitializer for a reusable variable to the workers @@ -247,6 +249,10 @@ message KillWorkersReply { bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks } +message RunFunctionOnAllWorkersRequest { + Function function = 1; +} + message ExportRemoteFunctionRequest { Function function = 1; } @@ -274,6 +280,7 @@ message ObjStoreInfoReply { service WorkerService { rpc ExecuteTask(ExecuteTaskRequest) returns (AckReply); // Scheduler calls a function from the worker + rpc RunFunctionOnWorker(RunFunctionOnWorkerRequest) returns (AckReply); // Runs a function on the worker. 
rpc ImportRemoteFunction(ImportRemoteFunctionRequest) returns (AckReply); // Scheduler imports a function into the worker rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker rpc Die(DieRequest) returns (AckReply); // Kills this worker @@ -284,6 +291,10 @@ message ExecuteTaskRequest { Task task = 1; // Contains name of the function to be executed and arguments } +message RunFunctionOnWorkerRequest { + Function function = 1; +} + message ImportRemoteFunctionRequest { Function function = 1; } @@ -302,6 +313,7 @@ message WorkerMessage { Task task = 1; // A task for the worker to execute. Function function = 2; // A remote function to import on the worker. ReusableVar reusable_variable = 3; // A reusable variable to import on the worker. + Function function_to_run = 4; // An arbitrary function to run on the worker. } } diff --git a/scripts/cluster.py b/scripts/cluster.py index f7a8483e0..9df802417 100644 --- a/scripts/cluster.py +++ b/scripts/cluster.py @@ -138,7 +138,7 @@ class RayCluster(object): """.format(self.installation_directory, self.installation_directory) self.run_command_over_ssh_on_all_nodes_in_parallel(install_ray_command) - def start_ray(self, user_source_directory=None, num_workers_per_node=10): + def start_ray(self, num_workers_per_node=10): """Start Ray on a cluster. This method is used to start Ray on a cluster. It will ssh to the head node, @@ -147,15 +147,8 @@ class RayCluster(object): workers. Args: - user_source_directory (Optional[str]): The path to the local directory - containing the user's source code. If provided, files and directories in - this directory can be used as modules in remote functions. num_workers_per_node (int): The number workers to start on each node. """ - # First update the worker code on the nodes. 
- if user_source_directory is not None: - remote_user_source_directory = self._update_user_code(user_source_directory) - scripts_directory = os.path.join(self.installation_directory, "ray/scripts") # Start the scheduler # The triple backslashes are used for two rounds of escaping, something like \\\" -> \" -> " @@ -169,18 +162,16 @@ class RayCluster(object): # Start the workers on each node # The triple backslashes are used for two rounds of escaping, something like \\\" -> \" -> " start_workers_commands = [] - remote_user_source_directory_str = "\\\"{}\\\"".format(remote_user_source_directory) if user_source_directory is not None else "None" for i, node_ip_address in enumerate(self.node_ip_addresses): start_workers_command = """ cd "{}"; source ../setup-env.sh; - python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {}, user_source_directory={})" > start_workers.out 2> start_workers.err < /dev/null & - """.format(scripts_directory, self.node_private_ip_addresses[0], self.node_private_ip_addresses[i], num_workers_per_node, remote_user_source_directory_str) + python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {})" > start_workers.out 2> start_workers.err < /dev/null & + """.format(scripts_directory, self.node_private_ip_addresses[0], self.node_private_ip_addresses[i], num_workers_per_node) start_workers_commands.append(start_workers_command) self.run_command_over_ssh_on_all_nodes_in_parallel(start_workers_commands) setup_env_path = os.path.join(self.installation_directory, "ray/setup-env.sh") - cd_location = remote_user_source_directory if user_source_directory is not None else os.path.join(self.installation_directory, "ray") print """ The cluster has been started. You can attach to the cluster by sshing to the head node with the following command. @@ -188,14 +179,13 @@ class RayCluster(object): Then run the following commands. - cd {} source {} # Add Ray to your Python path. 
Then within a Python interpreter or script, run the following commands. import ray ray.init(node_ip_address="{}", scheduler_address="{}:10001") - """.format(self.key_file, self.username, self.node_ip_addresses[0], cd_location, setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0]) + """.format(self.key_file, self.username, self.node_ip_addresses[0], setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0]) def stop_ray(self): """Kill all of the processes in the Ray cluster. @@ -230,15 +220,15 @@ class RayCluster(object): """.format(ray_directory, change_branch_command) self.run_command_over_ssh_on_all_nodes_in_parallel(update_cluster_command) - def _update_user_code(self, user_source_directory): + def copy_code_to_cluster(self, user_source_directory): """Update the user's source code on each node in the cluster. - This method is used to update the user's source code on each node in the + This method is used to copy the user's source code to each node in the cluster. The local user_source_directory will be copied under ray_source_files in the home directory on the worker node. For example, if - we call _update_source_code("~/a/b/c"), then the contents of "~/a/b/c" on - the local machine will be copied to "~/user_source_files/c" on each - node in the cluster. + we call copy_code_to_cluster("~/a/b/c"), then the contents of "~/a/b/c" on + the local machine will be copied to "~/ray_source_files/c" on each node in + the cluster. Args: user_source_directory (str): The path on the local machine to the directory @@ -253,7 +243,7 @@ class RayCluster(object): raise Exception("Directory {} does not exist.".format(user_source_directory)) # If user_source_directory is "/a/b/c", then local_directory_name is "c". 
local_directory_name = os.path.split(os.path.realpath(user_source_directory))[1] - remote_directory = os.path.join(self.installation_directory, "user_source_files", local_directory_name) + remote_directory = os.path.join(self.installation_directory, "ray_source_files", local_directory_name) # Remove and recreate the directory on the node. recreate_directory_command = """ rm -r "{}"; diff --git a/scripts/default_worker.py b/scripts/default_worker.py index 3c2706f2e..54f31d777 100644 --- a/scripts/default_worker.py +++ b/scripts/default_worker.py @@ -5,19 +5,12 @@ import numpy as np import ray parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.") -parser.add_argument("--user-source-directory", type=str, help="the directory containing the user's application code") parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node") parser.add_argument("--scheduler-address", required=True, type=str, help="the scheduler's address") parser.add_argument("--objstore-address", type=str, help="the objstore's address") if __name__ == "__main__": args = parser.parse_args() - if args.user_source_directory is not None: - # Adding the directory containing the user's application code to the Python - # path so that the worker can import Python modules from this directory. We - # insert into the first position (as opposed to the zeroth) because the - # zeroth position is reserved for the empty string. 
- sys.path.insert(1, args.user_source_directory) ray.worker.connect(args.node_ip_address, args.scheduler_address) ray.worker.main_loop() diff --git a/src/raylib.cc b/src/raylib.cc index 95144ceed..74e6aa959 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -718,7 +718,8 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { bool task_present = !message->task().name().empty(); bool function_present = !message->function().implementation().empty(); bool reusable_variable_present = !message->reusable_variable().name().empty(); - RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item."); + bool function_to_run_present = !message->function_to_run().implementation().empty(); + RAY_CHECK(task_present + function_present + reusable_variable_present + function_to_run_present <= 1, "The worker message should contain at most one item."); PyObject* t = PyTuple_New(2); if (task_present) { PyTuple_SetItem(t, 0, PyString_FromString("task")); @@ -736,6 +737,9 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { PyTuple_SetItem(reusable_variable, 1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast(message->reusable_variable().initializer().implementation().size()))); PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast(message->reusable_variable().reinitializer().implementation().size()))); PyTuple_SetItem(t, 1, reusable_variable); + } else if (function_to_run_present) { + PyTuple_SetItem(t, 0, PyString_FromString("function_to_run")); + PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function_to_run().implementation().data(), static_cast(message->function_to_run().implementation().size()))); } else { PyTuple_SetItem(t, 0, PyString_FromString("die")); Py_INCREF(Py_None); @@ -747,6 +751,17 @@ static PyObject* 
wait_for_next_message(PyObject* self, PyObject* args) { Py_RETURN_NONE; } +static PyObject* run_function_on_all_workers(PyObject* self, PyObject* args) { + Worker* worker; + const char* function; + int function_size; + if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) { + return NULL; + } + worker->run_function_on_all_workers(std::string(function, static_cast(function_size))); + Py_RETURN_NONE; +} + static PyObject* export_remote_function(PyObject* self, PyObject* args) { Worker* worker; const char* function_name; @@ -1088,6 +1103,7 @@ static PyMethodDef RayLibMethods[] = { { "ready_for_new_task", ready_for_new_task, METH_VARARGS, "notify the scheduler that the worker is ready for a new task" }, { "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" }, { "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" }, + { "run_function_on_all_workers", run_function_on_all_workers, METH_VARARGS, "run an arbitrary function on all workers" }, { "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" }, { "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" }, { "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" }, diff --git a/src/scheduler.cc b/src/scheduler.cc index 888fb2581..3faac2dbe 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -365,6 +365,8 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN // all of the exported functions and all of the exported reusable variables. if (!(*workers)[workerid].initialized) { // This should only happen once. + // Queue up all functions to run on the worker. + add_all_functions_to_run_to_worker_queue(workerid); // Queue up all remote functions to be imported on the worker. 
add_all_remote_functions_to_worker_export_queue(workerid); // Queue up all reusable variables to be imported on the worker. @@ -513,6 +515,22 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe return Status::OK; } +Status SchedulerService::RunFunctionOnAllWorkers(ServerContext* context, const RunFunctionOnAllWorkersRequest* request, AckReply* reply) { + { + auto workers = GET(workers_); + auto function_to_run_queue = GET(function_to_run_queue_); + auto exported_functions_to_run = GET(exported_functions_to_run_); + exported_functions_to_run->push_back(std::unique_ptr(new Function(request->function()))); + for (WorkerId workerid = 0; workerid < workers->size(); ++workerid) { + if ((*workers)[workerid].current_task != ROOT_OPERATION) { + function_to_run_queue->push(std::make_pair(workerid, exported_functions_to_run->size() - 1)); + } + } + } + schedule(); + return Status::OK; +} + Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) { { auto workers = GET(workers_); @@ -609,6 +627,10 @@ void SchedulerService::deliver_object_async(ObjectID canonical_objectid, ObjStor } void SchedulerService::schedule() { + // Run functions on workers. This must happen before we schedule tasks in + // order to guarantee that remote function calls use the most up to date + // environment. + perform_functions_to_run(); // Export remote functions to the workers. This must happen before we schedule // tasks in order to guarantee that remote function calls use the most up to // date definitions. 
@@ -801,6 +823,17 @@ bool SchedulerService::is_canonical(ObjectID objectid) { return objectid == (*target_objectids)[objectid]; } +void SchedulerService::perform_functions_to_run() { + auto workers = GET(workers_); + auto function_to_run_queue = GET(function_to_run_queue_); + auto exported_functions_to_run = GET(exported_functions_to_run_); + while (!function_to_run_queue->empty()) { + std::pair workerid_functionid_pair = function_to_run_queue->front(); + export_function_to_run_to_worker(workerid_functionid_pair.first, workerid_functionid_pair.second, workers, exported_functions_to_run); + function_to_run_queue->pop(); + } +} + void SchedulerService::perform_remote_function_exports() { auto workers = GET(workers_); auto remote_function_export_queue = GET(remote_function_export_queue_); @@ -1084,22 +1117,39 @@ void SchedulerService::get_equivalent_objectids(ObjectID objectid, std::vector > &workers, const MySynchronizedPtr > > &exported_functions_to_run) { + RAY_LOG(RAY_INFO, "exporting function to run with index " << function_index << " to worker " << workerid); + ClientContext context; + RunFunctionOnWorkerRequest request; + request.mutable_function()->CopyFrom(*(*exported_functions_to_run)[function_index].get()); + AckReply reply; + RAY_CHECK_GRPC((*workers)[workerid].worker_stub->RunFunctionOnWorker(&context, request, &reply)); +} + void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions) { - RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid); - ClientContext import_context; - ImportRemoteFunctionRequest import_request; - import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get()); - AckReply import_reply; - RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportRemoteFunction(&import_context, import_request, &import_reply)); + RAY_LOG(RAY_INFO, "exporting remote function with 
index " << function_index << " to worker " << workerid); + ClientContext context; + ImportRemoteFunctionRequest request; + request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get()); + AckReply reply; + RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportRemoteFunction(&context, request, &reply)); } void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables) { RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid); - ClientContext import_context; - ImportReusableVariableRequest import_request; - import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get()); - AckReply import_reply; - RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply)); + ClientContext context; + ImportReusableVariableRequest request; + request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get()); + AckReply reply; + RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportReusableVariable(&context, request, &reply)); +} + +void SchedulerService::add_all_functions_to_run_to_worker_queue(WorkerId workerid) { + auto function_to_run_queue = GET(function_to_run_queue_); + auto exported_functions_to_run = GET(exported_functions_to_run_); + for (int i = 0; i < exported_functions_to_run->size(); ++i) { + function_to_run_queue->push(std::make_pair(workerid, i)); + } } void SchedulerService::add_all_remote_functions_to_worker_export_queue(WorkerId workerid) { diff --git a/src/scheduler.h b/src/scheduler.h index aaac29044..bccc321a4 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -75,6 +75,7 @@ public: Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) 
override; Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override; Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override; + Status RunFunctionOnAllWorkers(ServerContext* context, const RunFunctionOnAllWorkersRequest* request, AckReply* reply) override; Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override; Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override; Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override; @@ -119,10 +120,13 @@ private: ObjStoreId pick_objstore(ObjectID objectid); // checks if objectid is a canonical objectid bool is_canonical(ObjectID objectid); + // Export all queued up functions to run. + void perform_functions_to_run(); // Export all queued up remote functions. void perform_remote_function_exports(); // Export all queued up reusable variables. void perform_reusable_variable_exports(); + // Perform all queued up gets that can be performed. void perform_gets(); // schedule tasks using the naive algorithm void schedule_tasks_naively(); @@ -148,10 +152,15 @@ private: void upstream_objectids(ObjectID objectid, std::vector &objectids, const MySynchronizedPtr > > &reverse_target_objectids); // Find all of the object IDs that refer to the same object as objectid (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known. void get_equivalent_objectids(ObjectID objectid, std::vector &equivalent_objectids); + // Export a function to run to a worker. + void export_function_to_run_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions_to_run); // Export a remote function to a worker. 
void export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_functions); // Export a reusable variable to a worker void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables); + // Add to the function to run export queue the job of exporting all functions + // to run to the given worker. This is used when a new worker registers. + void add_all_functions_to_run_to_worker_queue(WorkerId workerid); // Add to the remote function export queue the job of exporting all remote // functions to the given worker. This is used when a new worker registers. void add_all_remote_functions_to_worker_export_queue(WorkerId workerid); @@ -226,6 +235,11 @@ private: // lock (objects_lock_). // TODO(rkn): Consider making this part of the // objtable data structure. std::vector > objects_in_transit_; + // List of pending functions to run on workers. These should be processed in a + // first in first out manner. The first element of each pair is the ID of the + // worker to run the function on, and the second element of each pair is the + // index of the function to run. + Synchronized > > function_to_run_queue_; // List of pending remote function exports. These should be processed in a // first in first out manner. The first element of each pair is the ID of the // worker to export the remote function to, and the second element of each @@ -236,6 +250,8 @@ private: // worker to export the reusable variable to, and the second element of each // pair is the index of the reusable variable to export. Synchronized > > reusable_variable_export_queue_; + // All of the functions that have been exported to the workers to run. + Synchronized > > exported_functions_to_run_; // All of the remote functions that have been exported to the workers. 
Synchronized > > exported_functions_; // All of the reusable variables that have been exported to the workers. diff --git a/src/worker.cc b/src/worker.cc index 64d0b4062..eb7c9c6cf 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -26,6 +26,21 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR WorkerMessage* message_ptr = message.get(); RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full."); } + // The message will get deleted in receive_next_message(). + message.release(); + return Status::OK; +} + +Status WorkerServiceImpl::RunFunctionOnWorker(ServerContext* context, const RunFunctionOnWorkerRequest* request, AckReply* reply) { + RAY_CHECK(mode_ == Mode::WORKER_MODE, "RunFunctionOnWorker can only be called on workers."); + std::unique_ptr message(new WorkerMessage()); + message->mutable_function_to_run()->CopyFrom(request->function()); + RAY_LOG(RAY_INFO, "Running function on worker."); + { + WorkerMessage* message_ptr = message.get(); + RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full."); + } + // The message will get deleted in receive_next_message(). message.release(); return Status::OK; } @@ -39,6 +54,7 @@ Status WorkerServiceImpl::ImportRemoteFunction(ServerContext* context, const Imp WorkerMessage* message_ptr = message.get(); RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full."); } + // The message will get deleted in receive_next_message(). 
message.release(); return Status::OK; } @@ -52,6 +68,7 @@ Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const I WorkerMessage* message_ptr = message.get(); RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full."); } + // The message will get deleted in receive_next_message(). message.release(); return Status::OK; } @@ -440,6 +457,15 @@ std::vector Worker::select(std::vector& objectids) { return result; } +void Worker::run_function_on_all_workers(const std::string& function) { + RAY_CHECK(connected_, "Attempted to run function on all workers but failed."); + ClientContext context; + RunFunctionOnAllWorkersRequest request; + request.mutable_function()->set_implementation(function); + AckReply reply; + RAY_CHECK_GRPC(scheduler_stub_->RunFunctionOnAllWorkers(&context, request, &reply)); +} + bool Worker::export_remote_function(const std::string& function_name, const std::string& function) { RAY_CHECK(connected_, "Attempted to export function but failed."); ClientContext context; diff --git a/src/worker.h b/src/worker.h index 814cc1041..b7a86141a 100644 --- a/src/worker.h +++ b/src/worker.h @@ -32,6 +32,7 @@ class WorkerServiceImpl final : public WorkerService::Service { public: WorkerServiceImpl(const std::string& worker_address, Mode mode); Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override; + Status RunFunctionOnWorker(ServerContext* context, const RunFunctionOnWorkerRequest* request, AckReply* reply) override; Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override; Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override; Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override; @@ -104,6 +105,8 @@ class Worker { void 
task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply); // gets indices of available objects std::vector select(std::vector& objectids); + // Export a function to be run on all workers. + void run_function_on_all_workers(const std::string& function); // export function to workers bool export_remote_function(const std::string& function_name, const std::string& function); // export reusable variable to workers diff --git a/test/runtest.py b/test/runtest.py index 13f7a5500..fd26e2f8e 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -351,6 +351,23 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testRunningFunctionOnAllWorkers(self): + ray.init(start_ray_local=True, num_workers=1) + + def f(): + sys.path.append("fake_directory") + ray.worker.global_worker.run_function_on_all_workers(f) + @ray.remote([], [list]) + def get_path(): + return sys.path + self.assertEqual("fake_directory", ray.get(get_path.remote())[-1]) + def f(): + sys.path.pop(-1) + ray.worker.global_worker.run_function_on_all_workers(f) + self.assertTrue("fake_directory" not in ray.get(get_path.remote())) + + ray.worker.cleanup() + def check_get_deallocated(data): x = ray.put(data) ray.get(x)