Automatically add relevant directories to Python paths of workers (#380)

* Make ray.init set python paths of workers.

* Decouple starting cluster from copying user source code

* also add current directory to path

* Add comments about deallocation.

* Add test for new code path.
This commit is contained in:
Robert Nishihara
2016-08-16 14:53:55 -07:00
committed by Philipp Moritz
parent 7246013008
commit e06311d415
12 changed files with 222 additions and 64 deletions
+9 -12
View File
@@ -127,14 +127,9 @@ the tests.
python test/array_test.py # This tests some array libraries.
```
6. Start the cluster with `cluster.start_ray()`. If you would like to deploy
source code to it, you can pass in the local path to the directory that contains
your Python code. For example, `cluster.start_ray("~/example_ray_code")`. This
will copy your source code to each node on the cluster, placing it in a
directory on the PYTHONPATH.
The `cluster.start_ray` command will start the Ray scheduler, object stores, and
workers, and before finishing it will print instructions for connecting to the
cluster via ssh.
6. Start the cluster with `cluster.start_ray()`. The `cluster.start_ray` command
will start the Ray scheduler, object stores, and workers, and before finishing
it will print instructions for connecting to the cluster via ssh.
7. To connect to the cluster (either with a Python shell or with a script), ssh
to the cluster's head node (as described by the output of the
@@ -146,7 +141,6 @@ to the cluster's head node (as described by the output of the
Then run the following commands.
cd $HOME/ray
source $HOME/ray/setup-env.sh # Add Ray to your Python path.
Then within a Python interpreter, run the following commands.
@@ -177,11 +171,14 @@ need to install a few more Python packages. This can be done, within
- `cluster.install_ray()` - This pulls the Ray source code on each node,
builds all of the third party libraries, and builds the project itself.
- `cluster.start_ray(user_source_directory=None, num_workers_per_node=10)` -
This starts a scheduler process on the head node, and it starts an object
store and some workers on each node.
- `cluster.start_ray(num_workers_per_node=10)` - This starts a scheduler
process on the head node, and it starts an object store and some workers
on each node.
- `cluster.stop_ray()` - This shuts down the cluster (killing all of the
processes).
- `cluster.copy_code_to_cluster(user_source_directory)` - This copies the
contents of the local directory `user_source_directory` to each node in the
cluster, placing them under `~/ray_source_files/`.
- `cluster.update_ray()` - This pulls the latest Ray source code and builds
it.
- `cluster.run_command_over_ssh_on_all_nodes_in_parallel(command)` - This
+3 -13
View File
@@ -82,7 +82,7 @@ def start_objstore(scheduler_address, node_ip_address, cleanup):
if cleanup:
all_processes.append(p)
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True, user_source_directory=None):
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True):
"""This method starts a worker process.
Args:
@@ -96,18 +96,10 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
cleanup (Optional[bool]): True if using Ray in local mode. If cleanup is
true, then this process will be killed by services.cleanup() when the
Python process that imported services exits. This is True by default.
user_source_directory (Optional[str]): The directory containing the
application code. This directory will be added to the path of each worker.
If not provided, the directory of the script currently being run is used.
"""
if user_source_directory is None:
# This extracts the directory of the script that is currently being run.
# This will allow users to import modules contained in this directory.
user_source_directory = os.path.dirname(os.path.abspath(os.path.join(os.path.curdir, sys.argv[0])))
command = ["python",
worker_path,
"--node-ip-address=" + node_ip_address,
"--user-source-directory=" + user_source_directory,
"--scheduler-address=" + scheduler_address]
if objstore_address is not None:
command.append("--objstore-address=" + objstore_address)
@@ -115,7 +107,7 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
if cleanup:
all_processes.append(p)
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, user_source_directory=None, cleanup=False):
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, cleanup=False):
"""Start an object store and associated workers in the cluster setting.
This starts an object store and the associated workers when Ray is being used
@@ -129,8 +121,6 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
num_workers (int): The number of workers to be started on this node.
worker_path (str): Path of the Python worker script that will be run on the
worker.
user_source_directory (str): Path to the user's code the workers will import
modules from.
cleanup (bool): If cleanup is True, then the processes started by this
command will be killed when the process that imported services exits.
"""
@@ -139,7 +129,7 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
for _ in range(num_workers):
start_worker(node_ip_address, worker_path, scheduler_address, user_source_directory=user_source_directory, cleanup=cleanup)
start_worker(node_ip_address, worker_path, scheduler_address, cleanup=cleanup)
time.sleep(0.5)
def start_workers(scheduler_address, objstore_address, num_workers, worker_path):
+48
View File
@@ -1,4 +1,5 @@
import os
import sys
import time
import traceback
import copy
@@ -516,6 +517,23 @@ class Worker(object):
objectids = raylib.submit_task(self.handle, task_capsule)
return objectids
def run_function_on_all_workers(self, function):
    """Execute a zero-argument function on the driver and on every worker.

    The function is called locally first. It is then serialized and handed to
    the scheduler so that it is run on all of the workers, including workers
    that register later.

    Args:
      function (Callable): The function to run. It must not take any
        arguments, and anything it returns is ignored.
    """
    # Run locally (on the driver) before shipping anything to the workers.
    function()
    # Serialize the function and ask the scheduler to run it on the workers.
    serialized_function = pickling.dumps(function)
    raylib.run_function_on_all_workers(self.handle, serialized_function)
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
@@ -760,8 +778,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
_logger().setLevel(logging.DEBUG)
_logger().propagate = False
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
# Add the directory containing the script that is running to the Python
# paths of the workers. Also add the current directory. Note that this
# assumes that the directory structures on the machines in the clusters are
# the same.
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(lambda : sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda : sys.path.insert(1, current_directory))
# Export cached remote functions to the workers.
for function_name, function_to_export in worker.cached_remote_functions:
raylib.export_remote_function(worker.handle, function_name, function_to_export)
# Export cached reusable variables to the workers.
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
worker.cached_remote_functions = None
@@ -998,6 +1026,23 @@ def main_loop(worker=global_worker):
else:
_logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))
def process_function_to_run(serialized_function):
    """Run an arbitrary function on the worker.

    Failures are logged and do not propagate to the caller, so a failing
    function does not stop the worker's main loop.

    Args:
      serialized_function (str): The pickled function to run. The function is
        called with no arguments and its return value is discarded.
    """
    try:
        # Deserialize the function.
        # NOTE(review): this unpickles and executes whatever was delivered to
        # the worker, so it is only safe when the sender is trusted.
        function = pickling.loads(serialized_function)
        # Run the function.
        function()
    except Exception:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt and SystemExit are not swallowed here.
        # If an exception was thrown when the function was run, we record the
        # traceback and notify the scheduler of the failure.
        traceback_str = format_error_message(traceback.format_exc())
        _logger().info("Failed to run function on worker. Failed with message: \n\n{}\n".format(traceback_str))
        # Notify the scheduler that running the function failed.
        # TODO(rkn): Notify the scheduler.
    else:
        _logger().info("Successfully ran function on worker.")
while True:
command, command_args = raylib.wait_for_next_message(worker.handle)
try:
@@ -1013,6 +1058,9 @@ def main_loop(worker=global_worker):
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
process_reusable_variable(name, initializer_str, reinitializer_str)
elif command == "function_to_run":
serialized_function = command_args
process_function_to_run(serialized_function)
else:
_logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
assert False, "This code should be unreachable."
+12
View File
@@ -52,6 +52,8 @@ service Scheduler {
rpc TaskInfo(TaskInfoRequest) returns (TaskInfoReply);
// Kills the workers
rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
// Run a function on all workers
rpc RunFunctionOnAllWorkers(RunFunctionOnAllWorkersRequest) returns (AckReply);
// Exports function to the workers
rpc ExportRemoteFunction(ExportRemoteFunctionRequest) returns (AckReply);
// Ship an initializer and reinitializer for a reusable variable to the workers
@@ -247,6 +249,10 @@ message KillWorkersReply {
bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
}
// Request sent to the scheduler's RunFunctionOnAllWorkers RPC.
message RunFunctionOnAllWorkersRequest {
Function function = 1; // The function to run; its implementation holds the serialized function.
}
message ExportRemoteFunctionRequest {
Function function = 1;
}
@@ -274,6 +280,7 @@ message ObjStoreInfoReply {
service WorkerService {
rpc ExecuteTask(ExecuteTaskRequest) returns (AckReply); // Scheduler calls a function from the worker
rpc RunFunctionOnWorker(RunFunctionOnWorkerRequest) returns (AckReply); // Runs a function on the worker.
rpc ImportRemoteFunction(ImportRemoteFunctionRequest) returns (AckReply); // Scheduler imports a function into the worker
rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker
rpc Die(DieRequest) returns (AckReply); // Kills this worker
@@ -284,6 +291,10 @@ message ExecuteTaskRequest {
Task task = 1; // Contains name of the function to be executed and arguments
}
// Request sent by the scheduler to a worker's RunFunctionOnWorker RPC.
message RunFunctionOnWorkerRequest {
Function function = 1; // The function for the worker to run.
}
message ImportRemoteFunctionRequest {
Function function = 1;
}
@@ -302,6 +313,7 @@ message WorkerMessage {
Task task = 1; // A task for the worker to execute.
Function function = 2; // A remote function to import on the worker.
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
Function function_to_run = 4; // An arbitrary function to run on the worker.
}
}
+10 -20
View File
@@ -138,7 +138,7 @@ class RayCluster(object):
""".format(self.installation_directory, self.installation_directory)
self.run_command_over_ssh_on_all_nodes_in_parallel(install_ray_command)
def start_ray(self, user_source_directory=None, num_workers_per_node=10):
def start_ray(self, num_workers_per_node=10):
"""Start Ray on a cluster.
This method is used to start Ray on a cluster. It will ssh to the head node,
@@ -147,15 +147,8 @@ class RayCluster(object):
workers.
Args:
user_source_directory (Optional[str]): The path to the local directory
containing the user's source code. If provided, files and directories in
this directory can be used as modules in remote functions.
num_workers_per_node (int): The number of workers to start on each node.
"""
# First update the worker code on the nodes.
if user_source_directory is not None:
remote_user_source_directory = self._update_user_code(user_source_directory)
scripts_directory = os.path.join(self.installation_directory, "ray/scripts")
# Start the scheduler
# The triple backslashes are used for two rounds of escaping, something like \\\" -> \" -> "
@@ -169,18 +162,16 @@ class RayCluster(object):
# Start the workers on each node
# The triple backslashes are used for two rounds of escaping, something like \\\" -> \" -> "
start_workers_commands = []
remote_user_source_directory_str = "\\\"{}\\\"".format(remote_user_source_directory) if user_source_directory is not None else "None"
for i, node_ip_address in enumerate(self.node_ip_addresses):
start_workers_command = """
cd "{}";
source ../setup-env.sh;
python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {}, user_source_directory={})" > start_workers.out 2> start_workers.err < /dev/null &
""".format(scripts_directory, self.node_private_ip_addresses[0], self.node_private_ip_addresses[i], num_workers_per_node, remote_user_source_directory_str)
python -c "import ray; ray.services.start_node(\\\"{}:10001\\\", \\\"{}\\\", {})" > start_workers.out 2> start_workers.err < /dev/null &
""".format(scripts_directory, self.node_private_ip_addresses[0], self.node_private_ip_addresses[i], num_workers_per_node)
start_workers_commands.append(start_workers_command)
self.run_command_over_ssh_on_all_nodes_in_parallel(start_workers_commands)
setup_env_path = os.path.join(self.installation_directory, "ray/setup-env.sh")
cd_location = remote_user_source_directory if user_source_directory is not None else os.path.join(self.installation_directory, "ray")
print """
The cluster has been started. You can attach to the cluster by sshing to the head node with the following command.
@@ -188,14 +179,13 @@ class RayCluster(object):
Then run the following commands.
cd {}
source {} # Add Ray to your Python path.
Then within a Python interpreter or script, run the following commands.
import ray
ray.init(node_ip_address="{}", scheduler_address="{}:10001")
""".format(self.key_file, self.username, self.node_ip_addresses[0], cd_location, setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0])
""".format(self.key_file, self.username, self.node_ip_addresses[0], setup_env_path, self.node_private_ip_addresses[0], self.node_private_ip_addresses[0])
def stop_ray(self):
"""Kill all of the processes in the Ray cluster.
@@ -230,15 +220,15 @@ class RayCluster(object):
""".format(ray_directory, change_branch_command)
self.run_command_over_ssh_on_all_nodes_in_parallel(update_cluster_command)
def _update_user_code(self, user_source_directory):
def copy_code_to_cluster(self, user_source_directory):
"""Update the user's source code on each node in the cluster.
This method is used to update the user's source code on each node in the
This method is used to copy the user's source code to each node in the
cluster. The local user_source_directory will be copied under
ray_source_files in the home directory on the worker node. For example, if
we call _update_source_code("~/a/b/c"), then the contents of "~/a/b/c" on
the local machine will be copied to "~/user_source_files/c" on each
node in the cluster.
we call copy_code_to_cluster("~/a/b/c"), then the contents of "~/a/b/c" on
the local machine will be copied to "~/ray_source_files/c" on each node in
the cluster.
Args:
user_source_directory (str): The path on the local machine to the directory
@@ -253,7 +243,7 @@ class RayCluster(object):
raise Exception("Directory {} does not exist.".format(user_source_directory))
# If user_source_directory is "/a/b/c", then local_directory_name is "c".
local_directory_name = os.path.split(os.path.realpath(user_source_directory))[1]
remote_directory = os.path.join(self.installation_directory, "user_source_files", local_directory_name)
remote_directory = os.path.join(self.installation_directory, "ray_source_files", local_directory_name)
# Remove and recreate the directory on the node.
recreate_directory_command = """
rm -r "{}";
-7
View File
@@ -5,19 +5,12 @@ import numpy as np
import ray
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
parser.add_argument("--user-source-directory", type=str, help="the directory containing the user's application code")
parser.add_argument("--node-ip-address", required=True, type=str, help="the ip address of the worker's node")
parser.add_argument("--scheduler-address", required=True, type=str, help="the scheduler's address")
parser.add_argument("--objstore-address", type=str, help="the objstore's address")
if __name__ == "__main__":
args = parser.parse_args()
if args.user_source_directory is not None:
# Adding the directory containing the user's application code to the Python
# path so that the worker can import Python modules from this directory. We
# insert into the first position (as opposed to the zeroth) because the
# zeroth position is reserved for the empty string.
sys.path.insert(1, args.user_source_directory)
ray.worker.connect(args.node_ip_address, args.scheduler_address)
ray.worker.main_loop()
+17 -1
View File
@@ -718,7 +718,8 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
bool task_present = !message->task().name().empty();
bool function_present = !message->function().implementation().empty();
bool reusable_variable_present = !message->reusable_variable().name().empty();
RAY_CHECK(task_present + function_present + reusable_variable_present <= 1, "The worker message should contain at most one item.");
bool function_to_run_present = !message->function_to_run().implementation().empty();
RAY_CHECK(task_present + function_present + reusable_variable_present + function_to_run_present <= 1, "The worker message should contain at most one item.");
PyObject* t = PyTuple_New(2);
if (task_present) {
PyTuple_SetItem(t, 0, PyString_FromString("task"));
@@ -736,6 +737,9 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
PyTuple_SetItem(reusable_variable, 1, PyString_FromStringAndSize(message->reusable_variable().initializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().initializer().implementation().size())));
PyTuple_SetItem(reusable_variable, 2, PyString_FromStringAndSize(message->reusable_variable().reinitializer().implementation().data(), static_cast<ssize_t>(message->reusable_variable().reinitializer().implementation().size())));
PyTuple_SetItem(t, 1, reusable_variable);
} else if (function_to_run_present) {
PyTuple_SetItem(t, 0, PyString_FromString("function_to_run"));
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function_to_run().implementation().data(), static_cast<ssize_t>(message->function_to_run().implementation().size())));
} else {
PyTuple_SetItem(t, 0, PyString_FromString("die"));
Py_INCREF(Py_None);
@@ -747,6 +751,17 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
// Python binding for running a serialized function on all workers.
// Expects (worker, function_bytes) as arguments, forwards the bytes to
// Worker::run_function_on_all_workers, and returns None. Returns NULL if the
// arguments cannot be parsed.
static PyObject* run_function_on_all_workers(PyObject* self, PyObject* args) {
  Worker* worker;
  const char* function_data;
  int function_length;
  const bool parsed = PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function_data, &function_length);
  if (!parsed) {
    return NULL;
  }
  // Build the string from an explicit length so embedded null bytes survive.
  const std::string serialized_function(function_data, static_cast<size_t>(function_length));
  worker->run_function_on_all_workers(serialized_function);
  Py_RETURN_NONE;
}
static PyObject* export_remote_function(PyObject* self, PyObject* args) {
Worker* worker;
const char* function_name;
@@ -1088,6 +1103,7 @@ static PyMethodDef RayLibMethods[] = {
{ "ready_for_new_task", ready_for_new_task, METH_VARARGS, "notify the scheduler that the worker is ready for a new task" },
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
{ "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" },
{ "run_function_on_all_workers", run_function_on_all_workers, METH_VARARGS, "run an arbitrary function on all workers" },
{ "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" },
{ "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" },
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
+61 -11
View File
@@ -365,6 +365,8 @@ Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForN
// all of the exported functions and all of the exported reusable variables.
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Queue up all functions to run on the worker.
add_all_functions_to_run_to_worker_queue(workerid);
// Queue up all remote functions to be imported on the worker.
add_all_remote_functions_to_worker_export_queue(workerid);
// Queue up all reusable variables to be imported on the worker.
@@ -513,6 +515,22 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
return Status::OK;
}
// Handles the RunFunctionOnAllWorkers RPC. Stores a copy of the requested
// function in exported_functions_to_run_ and queues it, by index, for every
// entry in workers_ whose current_task is not ROOT_OPERATION, then calls
// schedule(). Always returns Status::OK; reply is not modified here.
Status SchedulerService::RunFunctionOnAllWorkers(ServerContext* context, const RunFunctionOnAllWorkersRequest* request, AckReply* reply) {
{
// NOTE(review): the inner scope presumably releases the handles obtained
// through GET before schedule() is called -- confirm against the definition
// of GET.
auto workers = GET(workers_);
auto function_to_run_queue = GET(function_to_run_queue_);
auto exported_functions_to_run = GET(exported_functions_to_run_);
// Keep a copy of the function so that it can be referred to by index; the
// stored list is also replayed for workers that register later (see
// add_all_functions_to_run_to_worker_queue).
exported_functions_to_run->push_back(std::unique_ptr<Function>(new Function(request->function())));
for (WorkerId workerid = 0; workerid < workers->size(); ++workerid) {
// NOTE(review): ROOT_OPERATION presumably marks a driver, which is skipped
// here -- confirm.
if ((*workers)[workerid].current_task != ROOT_OPERATION) {
// The function that was just appended is the last element, hence size() - 1.
function_to_run_queue->push(std::make_pair(workerid, exported_functions_to_run->size() - 1));
}
}
}
schedule();
return Status::OK;
}
Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) {
{
auto workers = GET(workers_);
@@ -609,6 +627,10 @@ void SchedulerService::deliver_object_async(ObjectID canonical_objectid, ObjStor
}
void SchedulerService::schedule() {
// Run functions on workers. This must happen before we schedule tasks in
// order to guarantee that remote function calls use the most up to date
// environment.
perform_functions_to_run();
// Export remote functions to the workers. This must happen before we schedule
// tasks in order to guarantee that remote function calls use the most up to
// date definitions.
@@ -801,6 +823,17 @@ bool SchedulerService::is_canonical(ObjectID objectid) {
return objectid == (*target_objectids)[objectid];
}
void SchedulerService::perform_functions_to_run() {
auto workers = GET(workers_);
auto function_to_run_queue = GET(function_to_run_queue_);
auto exported_functions_to_run = GET(exported_functions_to_run_);
while (!function_to_run_queue->empty()) {
std::pair<WorkerId, int> workerid_functionid_pair = function_to_run_queue->front();
export_function_to_run_to_worker(workerid_functionid_pair.first, workerid_functionid_pair.second, workers, exported_functions_to_run);
function_to_run_queue->pop();
}
}
void SchedulerService::perform_remote_function_exports() {
auto workers = GET(workers_);
auto remote_function_export_queue = GET(remote_function_export_queue_);
@@ -1084,22 +1117,39 @@ void SchedulerService::get_equivalent_objectids(ObjectID objectid, std::vector<O
}
// Send the function stored at function_index in exported_functions_to_run to
// the worker with ID workerid by calling RunFunctionOnWorker on that worker's
// stub. The caller passes in the already-acquired workers and
// exported_functions_to_run handles.
void SchedulerService::export_function_to_run_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions_to_run) {
RAY_LOG(RAY_INFO, "exporting function to run with index " << function_index << " to worker " << workerid);
ClientContext context;
RunFunctionOnWorkerRequest request;
// Copy the stored function into the request.
request.mutable_function()->CopyFrom(*(*exported_functions_to_run)[function_index].get());
AckReply reply;
// NOTE(review): RAY_CHECK_GRPC presumably fails hard when the RPC does not
// return an OK status -- confirm against the macro definition.
RAY_CHECK_GRPC((*workers)[workerid].worker_stub->RunFunctionOnWorker(&context, request, &reply));
}
void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
ClientContext import_context;
ImportRemoteFunctionRequest import_request;
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
AckReply import_reply;
RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportRemoteFunction(&import_context, import_request, &import_reply));
RAY_LOG(RAY_INFO, "exporting remote function with index " << function_index << " to worker " << workerid);
ClientContext context;
ImportRemoteFunctionRequest request;
request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
AckReply reply;
RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportRemoteFunction(&context, request, &reply));
}
void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
RAY_LOG(RAY_INFO, "exporting reusable variable with index " << reusable_variable_index << " to worker " << workerid);
ClientContext import_context;
ImportReusableVariableRequest import_request;
import_request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get());
AckReply import_reply;
RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportReusableVariable(&import_context, import_request, &import_reply));
ClientContext context;
ImportReusableVariableRequest request;
request.mutable_reusable_variable()->CopyFrom(*(*exported_reusable_variables)[reusable_variable_index].get());
AckReply reply;
RAY_CHECK_GRPC((*workers)[workerid].worker_stub->ImportReusableVariable(&context, request, &reply));
}
// Queue every function recorded in exported_functions_to_run_ for the given
// worker, in the order in which the functions were recorded.
void SchedulerService::add_all_functions_to_run_to_worker_queue(WorkerId workerid) {
  auto function_to_run_queue = GET(function_to_run_queue_);
  auto exported_functions_to_run = GET(exported_functions_to_run_);
  const int num_functions = static_cast<int>(exported_functions_to_run->size());
  int function_index = 0;
  while (function_index < num_functions) {
    function_to_run_queue->push(std::make_pair(workerid, function_index));
    ++function_index;
  }
}
void SchedulerService::add_all_remote_functions_to_worker_export_queue(WorkerId workerid) {
+16
View File
@@ -75,6 +75,7 @@ public:
Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override;
Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
Status RunFunctionOnAllWorkers(ServerContext* context, const RunFunctionOnAllWorkersRequest* request, AckReply* reply) override;
Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override;
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override;
@@ -119,10 +120,13 @@ private:
ObjStoreId pick_objstore(ObjectID objectid);
// checks if objectid is a canonical objectid
bool is_canonical(ObjectID objectid);
// Export all queued up functions to run.
void perform_functions_to_run();
// Export all queued up remote functions.
void perform_remote_function_exports();
// Export all queued up reusable variables.
void perform_reusable_variable_exports();
// Perform all queued up gets that can be performed.
void perform_gets();
// schedule tasks using the naive algorithm
void schedule_tasks_naively();
@@ -148,10 +152,15 @@ private:
void upstream_objectids(ObjectID objectid, std::vector<ObjectID> &objectids, const MySynchronizedPtr<std::vector<std::vector<ObjectID> > > &reverse_target_objectids);
// Find all of the object IDs that refer to the same object as objectid (as best as we can determine at the moment). The information may be incomplete because not all of the aliases may be known.
void get_equivalent_objectids(ObjectID objectid, std::vector<ObjectID> &equivalent_objectids);
// Export a function to run to a worker.
void export_function_to_run_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions_to_run);
// Export a remote function to a worker.
void export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions);
// Export a reusable variable to a worker
void export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables);
// Add to the function to run export queue the job of exporting all functions
// to run to the given worker. This is used when a new worker registers.
void add_all_functions_to_run_to_worker_queue(WorkerId workerid);
// Add to the remote function export queue the job of exporting all remote
// functions to the given worker. This is used when a new worker registers.
void add_all_remote_functions_to_worker_export_queue(WorkerId workerid);
@@ -226,6 +235,11 @@ private:
// lock (objects_lock_). // TODO(rkn): Consider making this part of the
// objtable data structure.
std::vector<std::vector<ObjectID> > objects_in_transit_;
// List of pending functions to run on workers. These should be processed in a
// first in first out manner. The first element of each pair is the ID of the
// worker to run the function on, and the second element of each pair is the
// index of the function to run.
Synchronized<std::queue<std::pair<WorkerId, int> > > function_to_run_queue_;
// List of pending remote function exports. These should be processed in a
// first in first out manner. The first element of each pair is the ID of the
// worker to export the remote function to, and the second element of each
@@ -236,6 +250,8 @@ private:
// worker to export the reusable variable to, and the second element of each
// pair is the index of the reusable variable to export.
Synchronized<std::queue<std::pair<WorkerId, int> > > reusable_variable_export_queue_;
// All of the functions that have been exported to the workers to run.
Synchronized<std::vector<std::unique_ptr<Function> > > exported_functions_to_run_;
// All of the remote functions that have been exported to the workers.
Synchronized<std::vector<std::unique_ptr<Function> > > exported_functions_;
// All of the reusable variables that have been exported to the workers.
+26
View File
@@ -26,6 +26,21 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR
WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full.");
}
// The message will get deleted in receive_next_message().
message.release();
return Status::OK;
}
// Handles the RunFunctionOnWorker RPC. Wraps the requested function in a
// WorkerMessage (as function_to_run) and sends a pointer to that message over
// send_queue_. Checks that this service is running in WORKER_MODE.
Status WorkerServiceImpl::RunFunctionOnWorker(ServerContext* context, const RunFunctionOnWorkerRequest* request, AckReply* reply) {
RAY_CHECK(mode_ == Mode::WORKER_MODE, "RunFunctionOnWorker can only be called on workers.");
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->mutable_function_to_run()->CopyFrom(request->function());
RAY_LOG(RAY_INFO, "Running function on worker.");
{
// Only the raw pointer is sent through the queue, not the message itself.
WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full.");
}
// The message will get deleted in receive_next_message().
// Releasing here keeps the unique_ptr from deleting the message that was
// just sent.
message.release();
return Status::OK;
}
@@ -39,6 +54,7 @@ Status WorkerServiceImpl::ImportRemoteFunction(ServerContext* context, const Imp
WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full.");
}
// The message will get deleted in receive_next_message().
message.release();
return Status::OK;
}
@@ -52,6 +68,7 @@ Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const I
WorkerMessage* message_ptr = message.get();
RAY_CHECK(send_queue_.send(&message_ptr), "Failed to send message from the worker service to the worker because the message queue was full.");
}
// The message will get deleted in receive_next_message().
message.release();
return Status::OK;
}
@@ -440,6 +457,15 @@ std::vector<int> Worker::select(std::vector<ObjectID>& objectids) {
return result;
}
void Worker::run_function_on_all_workers(const std::string& function) {
RAY_CHECK(connected_, "Attempted to run function on all workers but failed.");
ClientContext context;
RunFunctionOnAllWorkersRequest request;
request.mutable_function()->set_implementation(function);
AckReply reply;
RAY_CHECK_GRPC(scheduler_stub_->RunFunctionOnAllWorkers(&context, request, &reply));
}
bool Worker::export_remote_function(const std::string& function_name, const std::string& function) {
RAY_CHECK(connected_, "Attempted to export function but failed.");
ClientContext context;
+3
View File
@@ -32,6 +32,7 @@ class WorkerServiceImpl final : public WorkerService::Service {
public:
WorkerServiceImpl(const std::string& worker_address, Mode mode);
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override;
Status RunFunctionOnWorker(ServerContext* context, const RunFunctionOnWorkerRequest* request, AckReply* reply) override;
Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override;
Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override;
Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override;
@@ -104,6 +105,8 @@ class Worker {
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
// gets indices of available objects
std::vector<int> select(std::vector<ObjectID>& objectids);
// Export a function to be run on all workers.
void run_function_on_all_workers(const std::string& function);
// export function to workers
bool export_remote_function(const std::string& function_name, const std::string& function);
// export reusable variable to workers
+17
View File
@@ -351,6 +351,23 @@ class APITest(unittest.TestCase):
ray.worker.cleanup()
def testRunningFunctionOnAllWorkers(self):
    """Check that run_function_on_all_workers changes the workers' state.

    A function that appends to sys.path is run on all workers and the change
    is observed through a remote function. A second function then undoes the
    change, and the remote function is used to confirm that as well.
    """
    ray.init(start_ray_local=True, num_workers=1)
    # Use try/finally so that the processes started by ray.init are cleaned up
    # even when one of the assertions below fails.
    try:
        def f():
            sys.path.append("fake_directory")
        ray.worker.global_worker.run_function_on_all_workers(f)

        @ray.remote([], [list])
        def get_path():
            return sys.path
        self.assertEqual("fake_directory", ray.get(get_path.remote())[-1])

        # Undo the change; this also runs on the driver, restoring its path.
        def f():
            sys.path.pop(-1)
        ray.worker.global_worker.run_function_on_all_workers(f)
        self.assertTrue("fake_directory" not in ray.get(get_path.remote()))
    finally:
        ray.worker.cleanup()
def check_get_deallocated(data):
x = ray.put(data)
ray.get(x)