Give run_function_on_all_workers to take a worker_info dictionary including a counter. (#149)

* Suppress Redis warnings and remove some global scheduler logging.

* Pass a counter into run_function_on_all_workers indicating how many workers have begun executing this function.
This commit is contained in:
Robert Nishihara
2016-12-22 22:05:58 -08:00
committed by Philipp Moritz
parent 92010ca5b5
commit 86b211f5c2
5 changed files with 61 additions and 22 deletions
+4 -2
View File
@@ -360,8 +360,10 @@ def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, num_local_schedu
"""
if worker_path is None:
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py")
# Start Redis.
redis_port = start_redis(cleanup=cleanup, redirect_output=redirect_output)
# Start Redis. TODO(rkn): We are suppressing the output of Redis because on
# Linux it prints a bunch of warning messages when it starts up. Instead of
# suppressing the output, we should address the warnings.
redis_port = start_redis(cleanup=cleanup, redirect_output=True)
redis_address = address(node_ip_address, redis_port)
time.sleep(0.1)
# Start the global scheduler.
+18 -9
View File
@@ -523,11 +523,14 @@ class Worker(object):
if self.mode is None:
self.cached_functions_to_run.append(function)
else:
# First run the function on the driver.
function(self)
# Run the function on all workers.
function_to_run_id = random_string()
key = "FunctionsToRun:{}".format(function_to_run_id)
# First run the function on the driver. Pass in the number of workers on
# this node that have already started executing this remote function,
# and increment that value. Subtract 1 so that the counter starts at 0.
counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1
function({"counter": counter})
# Run the function on all workers.
self.redis_client.hmset(key, {"function_id": function_to_run_id,
"function": pickling.dumps(function)})
self.redis_client.rpush("Exports", key)
@@ -781,7 +784,7 @@ def print_error_messages(worker):
with worker.lock:
error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
for error_key in error_keys:
error_message = worker.redis_client.hget(error_key, "message")
error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
print(error_message)
num_errors_printed += 1
@@ -789,7 +792,7 @@ def print_error_messages(worker):
for msg in worker.error_message_pubsub_client.listen():
with worker.lock:
for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1):
error_message = worker.redis_client.hget(error_key, "message")
error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
print(error_message)
num_errors_printed += 1
except redis.ConnectionError:
@@ -849,11 +852,15 @@ def fetch_and_register_reusable_variable(key, worker=global_worker):
def fetch_and_execute_function_to_run(key, worker=global_worker):
"""Run on arbitrary function on the worker."""
serialized_function, = worker.redis_client.hmget(key, ["function"])
# Get the number of workers on this node that have already started executing
# this remote function, and increment that value. Subtract 1 so the counter
# starts at 0.
counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1
try:
# Deserialize the function.
function = pickling.loads(serialized_function)
# Run the function.
function(worker)
function({"counter": counter})
except:
# If an exception was thrown when the function was run, we record the
# traceback and notify the scheduler of the failure.
@@ -934,6 +941,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
# or to start the worker service.
if mode == PYTHON_MODE:
return
# Set the node IP address.
worker.node_ip_address = info["node_ip_address"]
# Create a Redis client.
redis_host, redis_port = info["redis_address"].split(":")
worker.redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
@@ -994,8 +1003,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
# the same.
script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
current_directory = os.path.abspath(os.path.curdir)
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory))
worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory))
worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory))
# TODO(rkn): Here we first export functions to run, then reusable variables,
# then remote functions. The order matters. For example, one of the
# functions to run may set the Python path, which is needed to import a
@@ -1054,7 +1063,7 @@ def register_class(cls, pickle=False, worker=global_worker):
# Raise an exception if cls cannot be serialized efficiently by Ray.
if not pickle:
serialization.check_serializable(cls)
def register_class_for_serialization(worker):
def register_class_for_serialization(worker_info):
serialization.add_class_to_whitelist(cls, pickle=pickle)
worker.run_function_on_all_workers(register_class_for_serialization)
+2 -1
View File
@@ -21,7 +21,8 @@ def random_string():
if __name__ == "__main__":
args = parser.parse_args()
info = {"redis_address": args.redis_address,
info = {"node_ip_address": args.node_ip_address,
"redis_address": args.redis_address,
"store_socket_name": args.object_store_name,
"manager_socket_name": args.object_store_manager_name,
"local_scheduler_socket_name": args.local_scheduler_name}