mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 10:45:02 +08:00
Automatically add relevant directories to Python paths of workers (#380)
* Make ray.init set python paths of workers. * Decouple starting cluster from copying user source code * also add current directory to path * Add comments about deallocation. * Add test for new code path.
This commit is contained in:
committed by
Philipp Moritz
parent
7246013008
commit
e06311d415
@@ -82,7 +82,7 @@ def start_objstore(scheduler_address, node_ip_address, cleanup):
|
||||
if cleanup:
|
||||
all_processes.append(p)
|
||||
|
||||
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True, user_source_directory=None):
|
||||
def start_worker(node_ip_address, worker_path, scheduler_address, objstore_address=None, cleanup=True):
|
||||
"""This method starts a worker process.
|
||||
|
||||
Args:
|
||||
@@ -96,18 +96,10 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
|
||||
cleanup (Optional[bool]): True if using Ray in local mode. If cleanup is
|
||||
true, then this process will be killed by serices.cleanup() when the
|
||||
Python process that imported services exits. This is True by default.
|
||||
user_source_directory (Optional[str]): The directory containing the
|
||||
application code. This directory will be added to the path of each worker.
|
||||
If not provided, the directory of the script currently being run is used.
|
||||
"""
|
||||
if user_source_directory is None:
|
||||
# This extracts the directory of the script that is currently being run.
|
||||
# This will allow users to import modules contained in this directory.
|
||||
user_source_directory = os.path.dirname(os.path.abspath(os.path.join(os.path.curdir, sys.argv[0])))
|
||||
command = ["python",
|
||||
worker_path,
|
||||
"--node-ip-address=" + node_ip_address,
|
||||
"--user-source-directory=" + user_source_directory,
|
||||
"--scheduler-address=" + scheduler_address]
|
||||
if objstore_address is not None:
|
||||
command.append("--objstore-address=" + objstore_address)
|
||||
@@ -115,7 +107,7 @@ def start_worker(node_ip_address, worker_path, scheduler_address, objstore_addre
|
||||
if cleanup:
|
||||
all_processes.append(p)
|
||||
|
||||
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, user_source_directory=None, cleanup=False):
|
||||
def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None, cleanup=False):
|
||||
"""Start an object store and associated workers in the cluster setting.
|
||||
|
||||
This starts an object store and the associated workers when Ray is being used
|
||||
@@ -129,8 +121,6 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
|
||||
num_workers (int): The number of workers to be started on this node.
|
||||
worker_path (str): Path of the Python worker script that will be run on the
|
||||
worker.
|
||||
user_source_directory (str): Path to the user's code the workers will import
|
||||
modules from.
|
||||
cleanup (bool): If cleanup is True, then the processes started by this
|
||||
command will be killed when the process that imported services exits.
|
||||
"""
|
||||
@@ -139,7 +129,7 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../scripts/default_worker.py")
|
||||
for _ in range(num_workers):
|
||||
start_worker(node_ip_address, worker_path, scheduler_address, user_source_directory=user_source_directory, cleanup=cleanup)
|
||||
start_worker(node_ip_address, worker_path, scheduler_address, cleanup=cleanup)
|
||||
time.sleep(0.5)
|
||||
|
||||
def start_workers(scheduler_address, objstore_address, num_workers, worker_path):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import copy
|
||||
@@ -516,6 +517,23 @@ class Worker(object):
|
||||
objectids = raylib.submit_task(self.handle, task_capsule)
|
||||
return objectids
|
||||
|
||||
def run_function_on_all_workers(self, function):
|
||||
"""Run arbitrary code on all of the workers.
|
||||
|
||||
This function will first be run on the driver, and then it will be exported
|
||||
to all of the workers to be run. It will also be run on any new workers that
|
||||
register later.
|
||||
|
||||
Args:
|
||||
function (Callable): The function to run on all of the workers. It should
|
||||
not take any arguments. If it returns anything, its return values will
|
||||
not be used.
|
||||
"""
|
||||
# First run the function on the driver.
|
||||
function()
|
||||
# Then run the function on all of the workers.
|
||||
raylib.run_function_on_all_workers(self.handle, pickling.dumps(function))
|
||||
|
||||
global_worker = Worker()
|
||||
"""Worker: The global Worker object for this worker process.
|
||||
|
||||
@@ -760,8 +778,18 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl
|
||||
_logger().setLevel(logging.DEBUG)
|
||||
_logger().propagate = False
|
||||
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
# Add the directory containing the script that is running to the Python
|
||||
# paths of the workers. Also add the current directory. Note that this
|
||||
# assumes that the directory structures on the machines in the clusters are
|
||||
# the same.
|
||||
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
|
||||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(lambda : sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda : sys.path.insert(1, current_directory))
|
||||
# Export cached remote functions to the workers.
|
||||
for function_name, function_to_export in worker.cached_remote_functions:
|
||||
raylib.export_remote_function(worker.handle, function_name, function_to_export)
|
||||
# Export cached reusable variables to the workers.
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
worker.cached_remote_functions = None
|
||||
@@ -998,6 +1026,23 @@ def main_loop(worker=global_worker):
|
||||
else:
|
||||
_logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))
|
||||
|
||||
def process_function_to_run(serialized_function):
|
||||
"""Run on arbitrary function on the worker."""
|
||||
try:
|
||||
# Deserialize the function.
|
||||
function = pickling.loads(serialized_function)
|
||||
# Run the function.
|
||||
function()
|
||||
except:
|
||||
# If an exception was thrown when the function was run, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
_logger().info("Failed to run function on worker. Failed with message: \n\n{}\n".format(traceback_str))
|
||||
# Notify the scheduler that running the function failed.
|
||||
# TODO(rkn): Notify the scheduler.
|
||||
else:
|
||||
_logger().info("Successfully ran function on worker.")
|
||||
|
||||
while True:
|
||||
command, command_args = raylib.wait_for_next_message(worker.handle)
|
||||
try:
|
||||
@@ -1013,6 +1058,9 @@ def main_loop(worker=global_worker):
|
||||
elif command == "reusable_variable":
|
||||
name, initializer_str, reinitializer_str = command_args
|
||||
process_reusable_variable(name, initializer_str, reinitializer_str)
|
||||
elif command == "function_to_run":
|
||||
serialized_function = command_args
|
||||
process_function_to_run(serialized_function)
|
||||
else:
|
||||
_logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
Reference in New Issue
Block a user