mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 01:27:43 +08:00
Give run_function_on_all_workers to take a worker_info dictionary including a counter. (#149)
* Suppress Redis warnings and remove some global scheduler logging. * Pass a counter into run_function_on_all_workers indicating how many workers have begun executing this function.
This commit is contained in:
committed by
Philipp Moritz
parent
92010ca5b5
commit
86b211f5c2
@@ -360,8 +360,10 @@ def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, num_local_schedu
|
||||
"""
|
||||
if worker_path is None:
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")
|
||||
# Start Redis.
|
||||
redis_port = start_redis(cleanup=cleanup, redirect_output=redirect_output)
|
||||
# Start Redis. TODO(rkn): We are suppressing the output of Redis because on
|
||||
# Linux it prints a bunch of warning messages when it starts up. Instead of
|
||||
# suppressing the output, we should address the warnings.
|
||||
redis_port = start_redis(cleanup=cleanup, redirect_output=True)
|
||||
redis_address = address(node_ip_address, redis_port)
|
||||
time.sleep(0.1)
|
||||
# Start the global scheduler.
|
||||
|
||||
@@ -523,11 +523,14 @@ class Worker(object):
|
||||
if self.mode is None:
|
||||
self.cached_functions_to_run.append(function)
|
||||
else:
|
||||
# First run the function on the driver.
|
||||
function(self)
|
||||
# Run the function on all workers.
|
||||
function_to_run_id = random_string()
|
||||
key = "FunctionsToRun:{}".format(function_to_run_id)
|
||||
# First run the function on the driver. Pass in the number of workers on
|
||||
# this node that have already started executing this remote function,
|
||||
# and increment that value. Subtract 1 so that the counter starts at 0.
|
||||
counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1
|
||||
function({"counter": counter})
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(key, {"function_id": function_to_run_id,
|
||||
"function": pickling.dumps(function)})
|
||||
self.redis_client.rpush("Exports", key)
|
||||
@@ -781,7 +784,7 @@ def print_error_messages(worker):
|
||||
with worker.lock:
|
||||
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
|
||||
for error_key in error_keys:
|
||||
error_message = worker.redis_client.hget(error_key, "message")
|
||||
error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
|
||||
print(error_message)
|
||||
num_errors_printed += 1
|
||||
|
||||
@@ -789,7 +792,7 @@ def print_error_messages(worker):
|
||||
for msg in worker.error_message_pubsub_client.listen():
|
||||
with worker.lock:
|
||||
for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1):
|
||||
error_message = worker.redis_client.hget(error_key, "message")
|
||||
error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
|
||||
print(error_message)
|
||||
num_errors_printed += 1
|
||||
except redis.ConnectionError:
|
||||
@@ -849,11 +852,15 @@ def fetch_and_register_reusable_variable(key, worker=global_worker):
|
||||
def fetch_and_execute_function_to_run(key, worker=global_worker):
|
||||
"""Run on arbitrary function on the worker."""
|
||||
serialized_function, = worker.redis_client.hmget(key, ["function"])
|
||||
# Get the number of workers on this node that have already started executing
|
||||
# this remote function, and increment that value. Subtract 1 so the counter
|
||||
# starts at 0.
|
||||
counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1
|
||||
try:
|
||||
# Deserialize the function.
|
||||
function = pickling.loads(serialized_function)
|
||||
# Run the function.
|
||||
function(worker)
|
||||
function({"counter": counter})
|
||||
except:
|
||||
# If an exception was thrown when the function was run, we record the
|
||||
# traceback and notify the scheduler of the failure.
|
||||
@@ -934,6 +941,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
|
||||
# or to start the worker service.
|
||||
if mode == PYTHON_MODE:
|
||||
return
|
||||
# Set the node IP address.
|
||||
worker.node_ip_address = info["node_ip_address"]
|
||||
# Create a Redis client.
|
||||
redis_host, redis_port = info["redis_address"].split(":")
|
||||
worker.redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
|
||||
@@ -994,8 +1003,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
|
||||
# the same.
|
||||
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
|
||||
current_directory = os.path.abspath(os.path.curdir)
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
|
||||
worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory))
|
||||
worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory))
|
||||
# TODO(rkn): Here we first export functions to run, then reusable variables,
|
||||
# then remote functions. The order matters. For example, one of the
|
||||
# functions to run may set the Python path, which is needed to import a
|
||||
@@ -1054,7 +1063,7 @@ def register_class(cls, pickle=False, worker=global_worker):
|
||||
# Raise an exception if cls cannot be serialized efficiently by Ray.
|
||||
if not pickle:
|
||||
serialization.check_serializable(cls)
|
||||
def register_class_for_serialization(worker):
|
||||
def register_class_for_serialization(worker_info):
|
||||
serialization.add_class_to_whitelist(cls, pickle=pickle)
|
||||
worker.run_function_on_all_workers(register_class_for_serialization)
|
||||
|
||||
|
||||
@@ -21,7 +21,8 @@ def random_string():
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
info = {"redis_address": args.redis_address,
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
"redis_address": args.redis_address,
|
||||
"store_socket_name": args.object_store_name,
|
||||
"manager_socket_name": args.object_store_manager_name,
|
||||
"local_scheduler_socket_name": args.local_scheduler_name}
|
||||
|
||||
Reference in New Issue
Block a user