Automatically add relevant directories to Python paths of workers (#380)

* Make ray.init set python paths of workers.

* Decouple starting cluster from copying user source code

* also add current directory to path

* Add comments about deallocation.

* Add test for new code path.
This commit is contained in:
Robert Nishihara
2016-08-16 14:53:55 -07:00
committed by Philipp Moritz
parent 7246013008
commit e06311d415
12 changed files with 222 additions and 64 deletions
+3 -13
View File
@@ -82,7 +82,7 @@ def start_objstore(scheduler_address, node_ip_address, cleanup):
if cleanup:
all_processes.append(p)
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True, user_source_directory=None):
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True):
"""This method starts a worker process.
Args:
@@ -96,18 +96,10 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
cleanup (Optional[bool]): True if using Ray in local mode. If cleanup is
true, then this process will be killed by services.cleanup() when the
Python process that imported services exits. This is True by default.
user_source_directory (Optional[str]): The directory containing the
application code. This directory will be added to the path of each worker.
If not provided, the directory of the script currently being run is used.
"""
if user_source_directory is None:
# This extracts the directory of the script that is currently being run.
# This will allow users to import modules contained in this directory.
user_source_directory = os.path.dirname(os.path.abspath(os.path.join(os.path.curdir, sys.argv[0])))
command = ["python",
worker_path,
"--node-ip-address=" + node_ip_address,
"--user-source-directory=" + user_source_directory,
"--scheduler-address=" + scheduler_address]
if objstore_address is not None:
command.append("--objstore-address=" + objstore_address)
@@ -115,7 +107,7 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
if cleanup:
all_processes.append(p)
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, user_source_directory=None, cleanup=False):
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, cleanup=False):
"""Start an object store and associated workers in the cluster setting.
This starts an object store and the associated workers when Ray is being used
@@ -129,8 +121,6 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
num_workers (int): The number of workers to be started on this node.
worker_path (str): Path of the Python worker script that will be run on the
worker.
user_source_directory (str): Path to the user's code the workers will import
modules from.
cleanup (bool): If cleanup is True, then the processes started by this
command will be killed when the process that imported services exits.
"""
@@ -139,7 +129,7 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
for _ in range(num_workers):
start_worker(node_ip_address, worker_path, scheduler_address, user_source_directory=user_source_directory, cleanup=cleanup)
start_worker(node_ip_address, worker_path, scheduler_address, cleanup=cleanup)
time.sleep(0.5)
def start_workers(scheduler_address, objstore_address, num_workers, worker_path):
+48
View File
@@ -1,4 +1,5 @@
import os
import sys
import time
import traceback
import copy
@@ -516,6 +517,23 @@ class Worker(object):
objectids = raylib.submit_task(self.handle, task_capsule)
return objectids
def run_function_on_all_workers(self, function):
    """Execute a zero-argument callable on the driver and on the workers.

    The callable is invoked locally first. It is then pickled and handed to
    raylib.run_function_on_all_workers so that it can be exported to all of
    the workers; it is also meant to be run on any workers that register
    later.

    Args:
        function (Callable): A callable that takes no arguments. Anything it
            returns is ignored.
    """
    # The driver-side invocation happens before anything is exported.
    function()
    # Serialize the callable and pass it on for execution on the workers.
    serialized_function = pickling.dumps(function)
    raylib.run_function_on_all_workers(self.handle, serialized_function)
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
@@ -760,8 +778,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
_logger().setLevel(logging.DEBUG)
_logger().propagate = False
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
# Add the directory containing the script that is running to the Python
# paths of the workers. Also add the current directory. Note that this
# assumes that the directory structures on the machines in the clusters are
# the same.
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(lambda : sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda : sys.path.insert(1, current_directory))
# Export cached remote functions to the workers.
for function_name, function_to_export in worker.cached_remote_functions:
raylib.export_remote_function(worker.handle, function_name, function_to_export)
# Export cached reusable variables to the workers.
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
worker.cached_remote_functions = None
@@ -998,6 +1026,23 @@ def main_loop(worker=global_worker):
else:
_logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))
def process_function_to_run(serialized_function):
    """Run an arbitrary function on the worker.

    The function is deserialized and then called with no arguments; its
    return value is discarded. Ordinary exceptions raised while
    deserializing or running it are logged and not propagated to the caller.
    KeyboardInterrupt and SystemExit are deliberately not intercepted.

    Args:
        serialized_function: The pickled function to run (presumably the
            output of pickling.dumps on the driver -- verify against the
            sender).
    """
    try:
        # Deserialize the function.
        # NOTE(review): unpickling can execute arbitrary code. This assumes
        # the payload only ever originates from a trusted driver -- confirm.
        function = pickling.loads(serialized_function)
        # Run the function.
        function()
    except Exception:
        # Catch Exception rather than using a bare except so that
        # interpreter-level signals (KeyboardInterrupt, SystemExit) are not
        # silently swallowed and reported as a function failure.
        # If an exception was thrown when the function was run, we record the
        # traceback and notify the scheduler of the failure.
        traceback_str = format_error_message(traceback.format_exc())
        _logger().info("Failed to run function on worker. Failed with message: \n\n{}\n".format(traceback_str))
        # Notify the scheduler that running the function failed.
        # TODO(rkn): Notify the scheduler.
    else:
        _logger().info("Successfully ran function on worker.")
while True:
command, command_args = raylib.wait_for_next_message(worker.handle)
try:
@@ -1013,6 +1058,9 @@ def main_loop(worker=global_worker):
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
process_reusable_variable(name, initializer_str, reinitializer_str)
elif command == "function_to_run":
serialized_function = command_args
process_function_to_run(serialized_function)
else:
_logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
assert False, "This code should be unreachable."