Unhandled exception handler based on local ref counting (#14049)

This commit is contained in:
Eric Liang
2021-02-12 22:58:38 -08:00
committed by GitHub
parent ff1b26274e
commit 9dc671ae02
11 changed files with 209 additions and 68 deletions
+22 -3
View File
@@ -724,6 +724,20 @@ cdef void delete_spilled_objects_handler(
job_id=None)
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
with gil:
worker = ray.worker.global_worker
data = None
metadata = None
if error.HasData():
data = Buffer.make(error.GetData())
if error.HasMetadata():
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
object_ids = [None]
worker.raise_errors([(data, metadata)], object_ids)
# This function introduces ~2-7us of overhead per call (i.e., it can be called
# up to hundreds of thousands of times per second).
cdef void get_py_stack(c_string* stack_out) nogil:
@@ -833,6 +847,7 @@ cdef class CoreWorker:
options.spill_objects = spill_objects_handler
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
@@ -1443,9 +1458,13 @@ cdef class CoreWorker:
object_ref.native())
def remove_object_ref_reference(self, ObjectRef object_ref):
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_ref.native())
cdef:
CObjectID c_object_id = object_ref.native()
# We need to release the gil since object destruction may call the
# unhandled exception handler.
with nogil:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
+1
View File
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(void(
const c_vector[c_string]&,
CWorkerType) nogil) delete_spilled_objects
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
+46
View File
@@ -20,6 +20,52 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
get_error_message, Semaphore)
def test_unhandled_errors(ray_start_regular):
@ray.remote
def f():
raise ValueError()
@ray.remote
class Actor:
def f(self):
raise ValueError()
a = Actor.remote()
num_exceptions = 0
def interceptor(e):
nonlocal num_exceptions
num_exceptions += 1
# Test we report unhandled exceptions.
ray.worker._unhandled_error_handler = interceptor
x1 = f.remote()
x2 = a.f.remote()
del x1
del x2
wait_for_condition(lambda: num_exceptions == 2)
# Test we don't report handled exceptions.
x1 = f.remote()
x2 = a.f.remote()
with pytest.raises(ray.exceptions.RayError) as err: # noqa
ray.get([x1, x2])
del x1
del x2
time.sleep(1)
assert num_exceptions == 2, num_exceptions
# Test suppression with env var works.
try:
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
x1 = f.remote()
del x1
time.sleep(1)
assert num_exceptions == 2, num_exceptions
finally:
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
def test_failed_task(ray_start_regular, error_pubsub):
@ray.remote
def throw_exception_fct1():
+19 -60
View File
@@ -9,7 +9,6 @@ import json
import logging
import os
import redis
from six.moves import queue
import sys
import threading
import time
@@ -69,6 +68,12 @@ ERROR_KEY_PREFIX = b"Error:"
logger = logging.getLogger(__name__)
# Visible for testing.
def _unhandled_error_handler(e: Exception):
logger.error("Unhandled error (suppress with "
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
class Worker:
"""A class used to define the control flow of a worker process.
@@ -277,6 +282,14 @@ class Worker:
self.core_worker.put_serialized_object(
serialized_value, object_ref=object_ref))
def raise_errors(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
out = context.deserialize_objects(data_metadata_pairs, object_refs)
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
return
for e in out:
_unhandled_error_handler(e)
def deserialize_objects(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
return context.deserialize_objects(data_metadata_pairs, object_refs)
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
# The last time we raised a TaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0
# The max amount of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5
def print_logs(redis_client, threads_stopped, job_id):
"""Prints log messages from workers on all of the nodes.
@@ -1020,42 +1026,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
file=print_file)
def print_error_messages_raylet(task_error_queue, threads_stopped):
"""Prints message received in the given output queue.
This checks periodically if any un-raised errors occurred in the
background.
Args:
task_error_queue (queue.Queue): A queue used to receive errors from the
thread that listens to Redis.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
while True:
# Exit if we received a signal that we should stop.
if threads_stopped.is_set():
return
try:
error, t = task_error_queue.get(block=False)
except queue.Empty:
threads_stopped.wait(timeout=0.01)
continue
# Delay errors a little bit of time to attempt to suppress redundant
# messages originating from the worker.
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
threads_stopped.wait(timeout=1)
if threads_stopped.is_set():
break
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
logger.debug(f"Suppressing error from worker: {error}")
else:
logger.error(f"Possible unhandled error from worker: {error}")
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
def listen_error_messages_raylet(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
@@ -1063,8 +1034,6 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
Args:
worker: The worker class that this thread belongs to.
task_error_queue (queue.Queue): A queue used to communicate with the
thread that prints the errors found by this thread.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1267,19 +1237,12 @@ def connect(node,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
q = queue.Queue()
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, q, worker.threads_stopped))
worker.printer_thread = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(q, worker.threads_stopped))
args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True
worker.listener_thread.start()
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
worker.listener_thread.join()
if hasattr(worker, "printer_thread"):
worker.printer_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()
worker.threads_stopped.clear()
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None):
raise ValueError("'object_refs' must either be an object ref "
"or a list of object refs.")
global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):