From 86b211f5c2bf973ccee1b5a7f8c323846254bbbf Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 22 Dec 2016 22:05:58 -0800 Subject: [PATCH] Give run_function_on_all_workers to take a worker_info dictionary including a counter. (#149) * Suppress Redis warnings and remove some global scheduler logging. * Pass a counter into run_function_on_all_workers indicating how many workers have begun executing this function. --- lib/python/ray/services.py | 6 ++- lib/python/ray/worker.py | 27 ++++++++---- lib/python/ray/workers/default_worker.py | 3 +- .../global_scheduler_algorithm.c | 6 +-- test/runtest.py | 41 +++++++++++++++---- 5 files changed, 61 insertions(+), 22 deletions(-) diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index fc9b9c034..9fd439c15 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -360,8 +360,10 @@ def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, num_local_schedu """ if worker_path is None: worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py") - # Start Redis. - redis_port = start_redis(cleanup=cleanup, redirect_output=redirect_output) + # Start Redis. TODO(rkn): We are suppressing the output of Redis because on + # Linux it prints a bunch of warning messages when it starts up. Instead of + # suppressing the output, we should address the warnings. + redis_port = start_redis(cleanup=cleanup, redirect_output=True) redis_address = address(node_ip_address, redis_port) time.sleep(0.1) # Start the global scheduler. diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index bd880ecd1..880b2ce98 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -523,11 +523,14 @@ class Worker(object): if self.mode is None: self.cached_functions_to_run.append(function) else: - # First run the function on the driver. - function(self) - # Run the function on all workers. function_to_run_id = random_string() key = "FunctionsToRun:{}".format(function_to_run_id) + # First run the function on the driver. Pass in the number of workers on + # this node that have already started executing this remote function, + # and increment that value. Subtract 1 so that the counter starts at 0. + counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1 + function({"counter": counter}) + # Run the function on all workers. self.redis_client.hmset(key, {"function_id": function_to_run_id, "function": pickling.dumps(function)}) self.redis_client.rpush("Exports", key) @@ -781,7 +784,7 @@ def print_error_messages(worker): with worker.lock: error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) for error_key in error_keys: - error_message = worker.redis_client.hget(error_key, "message") + error_message = worker.redis_client.hget(error_key, "message").decode("ascii") print(error_message) num_errors_printed += 1 @@ -789,7 +792,7 @@ def print_error_messages(worker): for msg in worker.error_message_pubsub_client.listen(): with worker.lock: for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1): - error_message = worker.redis_client.hget(error_key, "message") + error_message = worker.redis_client.hget(error_key, "message").decode("ascii") print(error_message) num_errors_printed += 1 except redis.ConnectionError: @@ -849,11 +852,15 @@ def fetch_and_register_reusable_variable(key, worker=global_worker): def fetch_and_execute_function_to_run(key, worker=global_worker): """Run on arbitrary function on the worker.""" serialized_function, = worker.redis_client.hmget(key, ["function"]) + # Get the number of workers on this node that have already started executing + # this remote function, and increment that value. Subtract 1 so the counter + # starts at 0. + counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1 try: # Deserialize the function. function = pickling.loads(serialized_function) # Run the function. - function(worker) + function({"counter": counter}) except: # If an exception was thrown when the function was run, we record the # traceback and notify the scheduler of the failure. @@ -934,6 +941,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker): # or to start the worker service. if mode == PYTHON_MODE: return + # Set the node IP address. + worker.node_ip_address = info["node_ip_address"] # Create a Redis client. redis_host, redis_port = info["redis_address"].split(":") worker.redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) @@ -994,8 +1003,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker): # the same. script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) current_directory = os.path.abspath(os.path.curdir) - worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, script_directory)) - worker.run_function_on_all_workers(lambda worker: sys.path.insert(1, current_directory)) + worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory)) # TODO(rkn): Here we first export functions to run, then reusable variables, # then remote functions. The order matters. For example, one of the # functions to run may set the Python path, which is needed to import a @@ -1054,7 +1063,7 @@ def register_class(cls, pickle=False, worker=global_worker): # Raise an exception if cls cannot be serialized efficiently by Ray. if not pickle: serialization.check_serializable(cls) - def register_class_for_serialization(worker): + def register_class_for_serialization(worker_info): serialization.add_class_to_whitelist(cls, pickle=pickle) worker.run_function_on_all_workers(register_class_for_serialization) diff --git a/lib/python/ray/workers/default_worker.py b/lib/python/ray/workers/default_worker.py index e96fe75c2..4ee01442d 100644 --- a/lib/python/ray/workers/default_worker.py +++ b/lib/python/ray/workers/default_worker.py @@ -21,7 +21,8 @@ def random_string(): if __name__ == "__main__": args = parser.parse_args() - info = {"redis_address": args.redis_address, + info = {"node_ip_address": args.node_ip_address, + "redis_address": args.redis_address, "store_socket_name": args.object_store_name, "manager_socket_name": args.object_store_manager_name, "local_scheduler_socket_name": args.local_scheduler_name} diff --git a/src/global_scheduler/global_scheduler_algorithm.c b/src/global_scheduler/global_scheduler_algorithm.c index 50c50a857..997bfbd77 100644 --- a/src/global_scheduler/global_scheduler_algorithm.c +++ b/src/global_scheduler/global_scheduler_algorithm.c @@ -51,9 +51,9 @@ object_size_entry *create_object_size_hashmap(global_scheduler_state *state, HASH_FIND(hh, state->scheduler_object_info_table, &obj_id, sizeof(obj_id), obj_info_entry); if (obj_info_entry == NULL) { - /* Global scheduler doesn't know anything about this object ID, so log a - * warning and skipt it. */ - LOG_WARN("Processing task with object ID not known to global scheduler"); + /* Global scheduler doesn't know anything about this object ID, so skip + * it. */ + LOG_DEBUG("Processing task with object ID not known to global scheduler"); continue; } LOG_DEBUG("[GS] found object id, data_size = %" PRId64, diff --git a/test/runtest.py b/test/runtest.py index ee6a2acef..6c9a5f5dc 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -410,16 +410,16 @@ class APITest(unittest.TestCase): def testCachingFunctionsToRun(self): # Test that we export functions to run on all workers before the driver is connected. - def f(worker): + def f(worker_info): sys.path.append(1) ray.worker.global_worker.run_function_on_all_workers(f) - def f(worker): + def f(worker_info): sys.path.append(2) ray.worker.global_worker.run_function_on_all_workers(f) - def g(worker): + def g(worker_info): sys.path.append(3) ray.worker.global_worker.run_function_on_all_workers(g) - def f(worker): + def f(worker_info): sys.path.append(4) ray.worker.global_worker.run_function_on_all_workers(f) @@ -436,7 +436,7 @@ class APITest(unittest.TestCase): self.assertEqual(ray.get(res2), (1, 2, 3, 4)) # Clean up the path on the workers. - def f(worker): + def f(worker_info): sys.path.pop() sys.path.pop() sys.path.pop() @@ -448,14 +448,14 @@ class APITest(unittest.TestCase): def testRunningFunctionOnAllWorkers(self): ray.init(start_ray_local=True, num_workers=1) - def f(worker): + def f(worker_info): sys.path.append("fake_directory") ray.worker.global_worker.run_function_on_all_workers(f) @ray.remote def get_path1(): return sys.path self.assertEqual("fake_directory", ray.get(get_path1.remote())[-1]) - def f(worker): + def f(worker_info): sys.path.pop(-1) ray.worker.global_worker.run_function_on_all_workers(f) # Create a second remote function to guarantee that when we call @@ -468,6 +468,33 @@ class APITest(unittest.TestCase): ray.worker.cleanup() + def testPassingInfoToAllWorkers(self): + ray.init(start_ray_local=True, num_workers=10) + + def f(worker_info): + sys.path.append(worker_info) + ray.worker.global_worker.run_function_on_all_workers(f) + @ray.remote + def get_path(): + time.sleep(1) + return sys.path + # Retrieve the values that we stored in the worker paths. + paths = ray.get([get_path.remote() for _ in range(10)]) + # Add the driver's path to the list. + paths.append(sys.path) + worker_infos = [path[-1] for path in paths] + for worker_info in worker_infos: + self.assertEqual(list(worker_info.keys()), ["counter"]) + counters = [worker_info["counter"] for worker_info in worker_infos] + # We use range(11) because the driver also runs the function. + self.assertEqual(set(counters), set(range(11))) + # Clean up the worker paths. + def f(worker_info): + sys.path.pop(-1) + ray.worker.global_worker.run_function_on_all_workers(f) + + ray.worker.cleanup() + class PythonModeTest(unittest.TestCase): def testPythonMode(self):