From 4eade036a0505e244c976f36aaa2d64386b5129b Mon Sep 17 00:00:00 2001 From: Yuhong Guo Date: Mon, 29 Apr 2019 14:55:37 +0800 Subject: [PATCH] Separate thread locks for worker and function manager. (#4499) * Separate lock for function manager and worker * Lint * Add test case * Remove print in remote function. * Remove test and add ray.exit_actor * Update python/ray/worker.py Co-Authored-By: guoyuhong * Move exit_actor from worker.py to actor.py * Update actor.py * Update actor.py --- python/ray/actor.py | 31 ++++++++--- python/ray/function_manager.py | 99 ++++++++++++++++++---------------- python/ray/import_thread.py | 28 +++++----- python/ray/worker.py | 86 ++++++++++++++--------------- 4 files changed, 131 insertions(+), 113 deletions(-) diff --git a/python/ray/actor.py b/python/ray/actor.py index 7953308df..3917042f2 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -708,12 +708,7 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions): def __ray_terminate__(self): worker = ray.worker.get_global_worker() if worker.mode != ray.LOCAL_MODE: - # Disconnect the worker from the raylet. The point of - # this is so that when the worker kills itself below, the - # raylet won't push an error message to the driver. - worker.raylet_client.disconnect() - sys.exit(0) - assert False, "This process should have terminated." + ray.actor.exit_actor() def __ray_checkpoint__(self): """Save a checkpoint. @@ -738,6 +733,30 @@ def make_actor(cls, num_cpus, num_gpus, resources, max_reconstructions): resources) +def exit_actor(): + """Intentionally exit the current actor. + + This function is used to disconnect an actor and exit the worker. + + Raises: + Exception: An exception is raised if this is a driver or this + worker is not an actor. + """ + worker = ray.worker.global_worker + if worker.mode == ray.WORKER_MODE and not worker.actor_id.is_nil(): + # Disconnect the worker from the raylet. 
The point of + # this is so that when the worker kills itself below, the + # raylet won't push an error message to the driver. + worker.raylet_client.disconnect() + ray.disconnect() + # Disconnect global state from GCS. + ray.global_state.disconnect() + sys.exit(0) + assert False, "This process should have terminated." + else: + raise Exception("exit_actor called on a non-actor worker.") + + ray.worker.global_worker.make_actor = make_actor CheckpointContext = namedtuple( diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 7651f0d62..e4a172fc1 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -9,6 +9,7 @@ import json import logging import sys import time +import threading import traceback from collections import ( namedtuple, @@ -300,6 +301,7 @@ class FunctionActorManager(object): # these types. self.imported_actor_classes = set() self._loaded_actor_classes = {} + self.lock = threading.Lock() def increase_task_counter(self, driver_id, function_descriptor): function_id = function_descriptor.function_id @@ -407,41 +409,48 @@ class FunctionActorManager(object): def f(): raise Exception("This function was not imported properly.") - self._function_execution_info[driver_id][function_id] = ( - FunctionExecutionInfo( - function=f, function_name=function_name, max_calls=max_calls)) - self._num_task_executions[driver_id][function_id] = 0 - - try: - function = pickle.loads(serialized_function) - except Exception: - # If an exception was thrown when the remote function was imported, - # we record the traceback and notify the scheduler of the failure. - traceback_str = format_error_message(traceback.format_exc()) - # Log the error message. - push_error_to_driver( - self._worker, - ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, - "Failed to unpickle the remote function '{}' with function ID " - "{}. 
Traceback:\n{}".format(function_name, function_id.hex(), - traceback_str), - driver_id=driver_id) - else: - # The below line is necessary. Because in the driver process, - # if the function is defined in the file where the python script - # was started from, its module is `__main__`. - # However in the worker process, the `__main__` module is a - # different module, which is `default_worker.py` - function.__module__ = module + # This function is called by ImportThread. This operation needs to be + # atomic. Otherwise, there is a race condition. Another thread may use + # the temporary function above before the real function is ready. + with self.lock: self._function_execution_info[driver_id][function_id] = ( FunctionExecutionInfo( - function=function, + function=f, function_name=function_name, max_calls=max_calls)) - # Add the function to the function table. - self._worker.redis_client.rpush( - b"FunctionTable:" + function_id.binary(), - self._worker.worker_id) + self._num_task_executions[driver_id][function_id] = 0 + + try: + function = pickle.loads(serialized_function) + except Exception: + # If an exception was thrown when the remote function was + # imported, we record the traceback and notify the scheduler + # of the failure. + traceback_str = format_error_message(traceback.format_exc()) + # Log the error message. + push_error_to_driver( + self._worker, + ray_constants.REGISTER_REMOTE_FUNCTION_PUSH_ERROR, + "Failed to unpickle the remote function '{}' with " + "function ID {}. Traceback:\n{}".format( + function_name, function_id.hex(), traceback_str), + driver_id=driver_id) + else: + # The below line is necessary. Because in the driver process, + # if the function is defined in the file where the python + # script was started from, its module is `__main__`. 
+ # However in the worker process, the `__main__` module is a + # different module, which is `default_worker.py` + function.__module__ = module + self._function_execution_info[driver_id][function_id] = ( + FunctionExecutionInfo( + function=function, + function_name=function_name, + max_calls=max_calls)) + # Add the function to the function table. + self._worker.redis_client.rpush( + b"FunctionTable:" + function_id.binary(), + self._worker.worker_id) def get_execution_info(self, driver_id, function_descriptor): """Get the FunctionExecutionInfo of a remote function. @@ -526,7 +535,7 @@ class FunctionActorManager(object): # Only send the warning once. warning_sent = False while True: - with self._worker.lock: + with self.lock: if (self._worker.actor_id.is_nil() and (function_descriptor.function_id in self._function_execution_info[driver_id])): @@ -534,18 +543,18 @@ class FunctionActorManager(object): elif not self._worker.actor_id.is_nil() and ( self._worker.actor_id in self._worker.actors): break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. You may have to restart " - "Ray.") - if not warning_sent: - ray.utils.push_error_to_driver( - self._worker, - ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, - warning_message, - driver_id=driver_id) - warning_sent = True + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. 
You may have to restart " + "Ray.") + if not warning_sent: + ray.utils.push_error_to_driver( + self._worker, + ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, + warning_message, + driver_id=driver_id) + warning_sent = True time.sleep(0.001) def _publish_actor_class_to_key(self, key, actor_class_info): @@ -716,7 +725,7 @@ class FunctionActorManager(object): actor_class = None try: - with self._worker.lock: + with self.lock: actor_class = pickle.loads(pickled_class) except Exception: logger.exception( diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 780db0be9..3de0cb079 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -56,11 +56,10 @@ class ImportThread(object): try: # Get the exports that occurred before the call to subscribe. - with self.worker.lock: - export_keys = self.redis_client.lrange("Exports", 0, -1) - for key in export_keys: - num_imported += 1 - self._process_key(key) + export_keys = self.redis_client.lrange("Exports", 0, -1) + for key in export_keys: + num_imported += 1 + self._process_key(key) while True: # Exit if we received a signal that we should stop. @@ -72,16 +71,15 @@ class ImportThread(object): self.threads_stopped.wait(timeout=0.01) continue - with self.worker.lock: - if msg["type"] == "subscribe": - continue - assert msg["data"] == b"rpush" - num_imports = self.redis_client.llen("Exports") - assert num_imports >= num_imported - for i in range(num_imported, num_imports): - num_imported += 1 - key = self.redis_client.lindex("Exports", i) - self._process_key(key) + if msg["type"] == "subscribe": + continue + assert msg["data"] == b"rpush" + num_imports = self.redis_client.llen("Exports") + assert num_imports >= num_imported + for i in range(num_imported, num_imports): + num_imported += 1 + key = self.redis_client.lindex("Exports", i) + self._process_key(key) finally: # Close the pubsub client to avoid leaking file descriptors. 
import_pubsub_client.close() diff --git a/python/ray/worker.py b/python/ray/worker.py index 5c4d416b9..b8fd28b30 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -234,9 +234,14 @@ class Worker(object): Returns: The serialization context of the given driver. """ - if driver_id not in self.serialization_context_map: - _initialize_serialization(driver_id) - return self.serialization_context_map[driver_id] + # This function needs to be protected by a lock, because it will be + # called by `register_class_for_serialization`, as well as the import + # thread, from different threads. Also, this function will recursively + # call itself, so we use RLock here. + with self.lock: + if driver_id not in self.serialization_context_map: + _initialize_serialization(driver_id) + return self.serialization_context_map[driver_id] def check_connected(self): """Check if the worker is connected. @@ -428,11 +433,7 @@ class Worker(object): # Wait a little bit for the import thread to import the class. # If we currently have the worker lock, we need to release it # so that the import thread can acquire it. - if self.mode == WORKER_MODE: - self.lock.release() time.sleep(0.01) - if self.mode == WORKER_MODE: - self.lock.acquire() if time.time() - start_time > error_timeout: warning_message = ("This worker or driver is waiting to " @@ -968,45 +969,37 @@ class Worker(object): driver_id, function_descriptor) # Execute the task. - # TODO(rkn): Consider acquiring this lock with a timeout and pushing a - # warning to the user if we are waiting too long to acquire the lock - # because that may indicate that the system is hanging, and it'd be - # good to know where the system is hanging. 
- with self.lock: - function_name = execution_info.function_name - extra_data = { - "name": function_name, - "task_id": task.task_id().hex() - } - if task.actor_id().is_nil(): - if task.actor_creation_id().is_nil(): - title = "ray_worker:{}()".format(function_name) - next_title = "ray_worker" - else: - actor = self.actors[task.actor_creation_id()] - title = "ray_{}:{}()".format(actor.__class__.__name__, - function_name) - next_title = "ray_{}".format(actor.__class__.__name__) + function_name = execution_info.function_name + extra_data = {"name": function_name, "task_id": task.task_id().hex()} + if task.actor_id().is_nil(): + if task.actor_creation_id().is_nil(): + title = "ray_worker:{}()".format(function_name) + next_title = "ray_worker" else: - actor = self.actors[task.actor_id()] + actor = self.actors[task.actor_creation_id()] title = "ray_{}:{}()".format(actor.__class__.__name__, function_name) next_title = "ray_{}".format(actor.__class__.__name__) - with profiling.profile("task", extra_data=extra_data): - with _changeproctitle(title, next_title): - self._process_task(task, execution_info) - # Reset the state fields so the next task can run. - self.task_context.current_task_id = TaskID.nil() - self.task_context.task_index = 0 - self.task_context.put_index = 1 - if self.actor_id.is_nil(): - # Don't need to reset task_driver_id if the worker is an - # actor. Because the following tasks should all have the - # same driver id. - self.task_driver_id = DriverID.nil() - # Reset signal counters so that the next task can get - # all past signals. - ray_signal.reset() + else: + actor = self.actors[task.actor_id()] + title = "ray_{}:{}()".format(actor.__class__.__name__, + function_name) + next_title = "ray_{}".format(actor.__class__.__name__) + with profiling.profile("task", extra_data=extra_data): + with _changeproctitle(title, next_title): + self._process_task(task, execution_info) + # Reset the state fields so the next task can run. 
+ self.task_context.current_task_id = TaskID.nil() + self.task_context.task_index = 0 + self.task_context.put_index = 1 + if self.actor_id.is_nil(): + # Don't need to reset task_driver_id if the worker is an + # actor. Because the following tasks should all have the + # same driver id. + self.task_driver_id = DriverID.nil() + # Reset signal counters so that the next task can get + # all past signals. + ray_signal.reset() # Increase the task execution counter. self.function_actor_manager.increase_task_counter( @@ -1645,10 +1638,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): try: # Get the exports that occurred before the call to subscribe. - with worker.lock: - error_messages = global_state.error_messages(worker.task_driver_id) - for error_message in error_messages: - logger.error(error_message) + error_messages = global_state.error_messages(worker.task_driver_id) + for error_message in error_messages: + logger.error(error_message) while True: # Exit if we received a signal that we should stop. @@ -1774,7 +1766,7 @@ def connect(node, traceback_str, driver_id=None) - worker.lock = threading.Lock() + worker.lock = threading.RLock() # Create an object for interfacing with the global state. global_state._initialize_global_state(