From f9f667de474b7d249025e49dc95108568ff1ffe7 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Thu, 29 Dec 2016 00:11:13 -0800 Subject: [PATCH] Improve formatting of error messages. (#154) * Improve formatting of error messages. * Catch errors that occur when looking up function name from function ID. * Push warning to user if worker spends too long waiting for proper import counter. * Fixes. * Add comment. --- lib/python/ray/worker.py | 118 ++++++++++++++++++++++++++++++--------- 1 file changed, 93 insertions(+), 25 deletions(-) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 6bd3f89f2..18176cdca 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -596,7 +596,8 @@ def error_info(worker=global_worker): b"RemoteFunctionImportError": [], b"ReusableVariableImportError": [], b"ReusableVariableReinitializeError": [], - b"FunctionToRunError": [] + b"FunctionToRunError": [], + b"GenericWarning": [], } error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1) for error_key in error_keys: @@ -857,6 +858,20 @@ def print_error_messages(worker): This runs in a separate thread on the driver and prints error messages in the background. """ + # TODO(rkn): All error messages should have a "component" field indicating + # which process the error came from (e.g., a worker or a plasma store). + # Currently all error messages come from workers. 
+ + helpful_message = """ +You can inspect errors by running + + ray.error_info() + +If this driver is hanging, start a new one with + + ray.init(redis_address="{}") +""".format(worker.redis_address) + worker.error_message_pubsub_client = worker.redis_client.pubsub() # Exports that are published after the call to # error_message_pubsub_client.psubscribe and before the call to @@ -870,6 +885,7 @@ def print_error_messages(worker): for error_key in error_keys: error_message = worker.redis_client.hget(error_key, "message").decode("ascii") print(error_message) + print(helpful_message) num_errors_printed += 1 try: @@ -878,6 +894,7 @@ def print_error_messages(worker): for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1): error_message = worker.redis_client.hget(error_key, "message").decode("ascii") print(error_message) + print(helpful_message) num_errors_printed += 1 except redis.ConnectionError: # When Redis terminates the listen call will throw a ConnectionError, which @@ -1027,6 +1044,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker): return # Set the node IP address. worker.node_ip_address = info["node_ip_address"] + worker.redis_address = info["redis_address"] # Create a Redis client. redis_host, redis_port = info["redis_address"].split(":") worker.redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port)) @@ -1232,7 +1250,36 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids] return ready_ids, remaining_ids -def format_error_message(exception_message): +def wait_for_valid_import_counter(function_id, timeout=5, worker=global_worker): + """Wait until this worker has imported enough to execute the function. + + This method will simply loop until the import thread has imported enough of + the exports to execute the function. 
If we spend too long in this loop, that + may indicate a problem somewhere and we will push an error message to the + user. + + Args: + function_id: The ID object (its id() method gives the lookup key) of the function that we want to execute. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + num_warnings_sent = 0 + while True: + with worker.lock: + if function_id.id() in worker.functions and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter): + break + if time.time() - start_time > timeout * (num_warnings_sent + 1): + if function_id.id() not in worker.functions: + warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray." + else: + warning_message = "This worker's import counter is too small." + if not warning_sent: + push_warning_to_user(warning_message, worker=worker) + warning_sent = True + time.sleep(0.001) + +def format_error_message(exception_message, task_exception=False): """Improve the formatting of an exception thrown by a remote function. This method takes a traceback from an exception and makes it nicer by @@ -1246,9 +1293,11 @@ def format_error_message(exception_message): A string of the formatted exception message. """ lines = exception_message.split("\n") - # Remove lines 1, 2, 3, and 4, which are always the same, they just contain - # information about the main loop. - lines = lines[0:1] + lines[5:] + if task_exception: + # For errors that occur inside of tasks, remove lines 1, 2, 3, and 4, + # which are always the same, they just contain information about the main + # loop. + lines = lines[0:1] + lines[5:] return "\n".join(lines) def main_loop(worker=global_worker): @@ -1272,24 +1321,37 @@ def main_loop(worker=global_worker): After the task executes, the worker resets any reusable variables that were accessed by the task. 
""" - worker.current_task_id = task.task_id() - worker.task_index = 0 - worker.put_index = 0 - function_id = task.function_id() - args = task.arguments() - return_object_ids = task.returns() - function_name = worker.function_names[function_id.id()] try: - arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker) # get args from objstore - outputs = worker.functions[function_id.id()].executor(arguments) # execute the function + worker.current_task_id = task.task_id() + worker.task_index = 0 + worker.put_index = 0 + function_id = task.function_id() + args = task.arguments() + return_object_ids = task.returns() + function_name = worker.function_names[function_id.id()] + # Get task arguments from the object store. + arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker) + # Execute the task. + outputs = worker.functions[function_id.id()].executor(arguments) + # Store the outputs in the local object store. if len(return_object_ids) == 1: outputs = (outputs,) - store_outputs_in_objstore(return_object_ids, outputs, worker) # store output in local object store + store_outputs_in_objstore(return_object_ids, outputs, worker) except Exception as e: - # If the task threw an exception, then record the traceback. We determine - # whether the exception was thrown in the task execution by whether the - # variable "arguments" is defined. - traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None + # We determine whether the exception was caused by the call to + # get_arguments_for_execution or by the execution of the remote function + # or by the call to store_outputs_in_objstore. Depending on which case + # occurred, we format the error message differently. We tell the cases + # apart by whether the variables "arguments" and "outputs" are defined. + if "arguments" in locals() and "outputs" not in locals(): + # The error occurred during the task execution. 
+ traceback_str = format_error_message(traceback.format_exc(), task_exception=True) + elif "arguments" in locals() and "outputs" in locals(): + # The error occurred after the task executed. + traceback_str = format_error_message(traceback.format_exc()) + else: + # The error occurred before the task execution. + traceback_str = None failure_object = RayTaskError(function_name, e, traceback_str) failure_objects = [failure_object for _ in range(len(return_object_ids))] store_outputs_in_objstore(return_object_ids, failure_objects, worker) @@ -1297,7 +1359,7 @@ def main_loop(worker=global_worker): error_key = "TaskError:{}".format(random_string()) worker.redis_client.hmset(error_key, {"function_id": function_id.id(), "function_name": function_name, - "message": traceback_str}) + "message": str(failure_object)}) worker.redis_client.rpush("ErrorKeys", error_key) try: # Reinitialize the values of reusable variables that were used in the task @@ -1320,15 +1382,21 @@ def main_loop(worker=global_worker): function_id = task.function_id() # Check that the number of imports we have is at least as great as the # export counter for the task. If not, wait until we have imported enough. - while True: - with worker.lock: - if function_id.id() in worker.functions and (worker.function_export_counters[function_id.id()] <= worker.worker_import_counter): - break - time.sleep(0.001) + # We will push warnings to the user if we spend too long in this loop. + wait_for_valid_import_counter(function_id, worker=worker) # Execute the task. + # TODO(rkn): Consider acquiring this lock with a timeout and pushing a + # warning to the user if we are waiting too long to acquire the lock because + # that may indicate that the system is hanging, and it'd be good to know + # where the system is hanging. 
with worker.lock: process_task(task) +def push_warning_to_user(message, worker=global_worker): + error_key = "GenericWarning:{}".format(random_string()) + worker.redis_client.hmset(error_key, {"message": message}) + worker.redis_client.rpush("ErrorKeys", error_key) + def _submit_task(function_id, func_name, args, worker=global_worker): """This is a wrapper around worker.submit_task.