Revert "Unhandled exception handler based on local ref counting (#14049)" (#14099)

This reverts commit 9dc671ae02.
This commit is contained in:
SangBin Cho
2021-02-14 22:08:32 -08:00
committed by GitHub
parent 75568f856c
commit b45ae76765
11 changed files with 68 additions and 209 deletions
+3 -22
View File
@@ -724,20 +724,6 @@ cdef void delete_spilled_objects_handler(
job_id=None)
# Callback assigned to `options.unhandled_exception_handler` in CoreWorker.
# Receives an error object and forwards its data/metadata to the Python
# worker's `raise_errors`.
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
# The function is declared `nogil`, so the GIL is re-acquired here before any
# Python object is touched.
with gil:
worker = ray.worker.global_worker
# Either part may be missing on the error object; it stays None in that case.
data = None
metadata = None
if error.HasData():
data = Buffer.make(error.GetData())
if error.HasMetadata():
# Metadata is passed through `to_pybytes()`; data is passed as the Buffer.
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
# A single None is used in place of an object ref (see the TODO above).
object_ids = [None]
worker.raise_errors([(data, metadata)], object_ids)
# This function introduces ~2-7us of overhead per call (i.e., it can be called
# up to hundreds of thousands of times per second).
cdef void get_py_stack(c_string* stack_out) nogil:
@@ -847,7 +833,6 @@ cdef class CoreWorker:
options.spill_objects = spill_objects_handler
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
@@ -1458,13 +1443,9 @@ cdef class CoreWorker:
object_ref.native())
def remove_object_ref_reference(self, ObjectRef object_ref):
cdef:
CObjectID c_object_id = object_ref.native()
# We need to release the gil since object destruction may call the
# unhandled exception handler.
with nogil:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_ref.native())
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
-1
View File
@@ -250,7 +250,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(void(
const c_vector[c_string]&,
CWorkerType) nogil) delete_spilled_objects
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
-46
View File
@@ -20,52 +20,6 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
get_error_message, Semaphore)
def test_unhandled_errors(ray_start_regular):
    """Check when the unhandled-error handler is (and is not) invoked.

    Three scenarios are exercised in order:

    1. Refs to a failing task and a failing actor call are dropped without
       being fetched; the handler must be called once for each.
    2. The same refs are fetched with ``ray.get`` (which raises) before being
       dropped; the handler must not be called again.
    3. With ``RAY_IGNORE_UNHANDLED_ERRORS`` set, a dropped failing ref must
       not reach the handler.

    The module-level handler and the environment variable are both restored
    on exit so that this test does not affect tests that run after it.
    """

    @ray.remote
    def f():
        raise ValueError()

    @ray.remote
    class Actor:
        def f(self):
            raise ValueError()

    a = Actor.remote()
    num_exceptions = 0

    def interceptor(e):
        nonlocal num_exceptions
        num_exceptions += 1

    env_var = "RAY_IGNORE_UNHANDLED_ERRORS"
    # Remember the real handler: it is a module global, so leaving the
    # interceptor installed would leak into every later test.
    original_handler = ray.worker._unhandled_error_handler
    ray.worker._unhandled_error_handler = interceptor
    try:
        # Test we report unhandled exceptions.
        x1 = f.remote()
        x2 = a.f.remote()
        del x1
        del x2
        wait_for_condition(lambda: num_exceptions == 2)

        # Test we don't report handled exceptions.
        x1 = f.remote()
        x2 = a.f.remote()
        with pytest.raises(ray.exceptions.RayError):
            ray.get([x1, x2])
        del x1
        del x2
        # Give a spurious report time to arrive before checking the count.
        time.sleep(1)
        assert num_exceptions == 2, num_exceptions

        # Test suppression with env var works.
        previous_value = os.environ.get(env_var)
        os.environ[env_var] = "1"
        try:
            x1 = f.remote()
            del x1
            time.sleep(1)
            assert num_exceptions == 2, num_exceptions
        finally:
            # Put the variable back exactly as it was found.
            if previous_value is None:
                del os.environ[env_var]
            else:
                os.environ[env_var] = previous_value
    finally:
        ray.worker._unhandled_error_handler = original_handler
def test_failed_task(ray_start_regular, error_pubsub):
@ray.remote
def throw_exception_fct1():
+60 -19
View File
@@ -9,6 +9,7 @@ import json
import logging
import os
import redis
from six.moves import queue
import sys
import threading
import time
@@ -68,12 +69,6 @@ ERROR_KEY_PREFIX = b"Error:"
logger = logging.getLogger(__name__)
# Visible for testing.
def _unhandled_error_handler(e: Exception):
    """Log an unhandled error at ERROR level.

    The message also names the environment variable that suppresses these
    reports.

    Args:
        e: The error to report.
    """
    template = ("Unhandled error (suppress with "
                "RAY_IGNORE_UNHANDLED_ERRORS=1): {}")
    message = template.format(e)
    logger.error(message)
class Worker:
"""A class used to define the control flow of a worker process.
@@ -282,14 +277,6 @@ class Worker:
self.core_worker.put_serialized_object(
serialized_value, object_ref=object_ref))
def raise_errors(self, data_metadata_pairs, object_refs):
    """Deserialize error objects and pass each to the unhandled handler.

    Deserialization always runs. The handler calls are skipped when the
    RAY_IGNORE_UNHANDLED_ERRORS environment variable is present, whatever
    its value.

    Args:
        data_metadata_pairs: (data, metadata) pairs to deserialize.
        object_refs: Object refs passed to the deserializer alongside
            the pairs.
    """
    errors = self.get_serialization_context().deserialize_objects(
        data_metadata_pairs, object_refs)
    if "RAY_IGNORE_UNHANDLED_ERRORS" not in os.environ:
        for err in errors:
            _unhandled_error_handler(err)
def deserialize_objects(self, data_metadata_pairs, object_refs):
context = self.get_serialization_context()
return context.deserialize_objects(data_metadata_pairs, object_refs)
@@ -876,6 +863,13 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
# The last time we raised a TaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0
# The maximum number of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5
def print_logs(redis_client, threads_stopped, job_id):
"""Prints log messages from workers on all of the nodes.
@@ -1026,7 +1020,42 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
file=print_file)
def listen_error_messages_raylet(worker, threads_stopped):
def print_error_messages_raylet(task_error_queue, threads_stopped):
    """Print error messages received on the given queue.

    Each queued error is held back until UNCAUGHT_ERROR_GRACE_PERIOD seconds
    have passed since the timestamp it was queued with. It is then logged at
    debug level (suppressed) if that timestamp is earlier than
    ``last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD``, and logged
    as a possible unhandled error otherwise.

    Args:
        task_error_queue (queue.Queue): A queue of ``(error, timestamp)``
            tuples received from the thread that listens to Redis.
        threads_stopped (threading.Event): A threading event used to signal to
            the thread that it should exit.
    """
    # Exit as soon as we are signalled to stop.
    while not threads_stopped.is_set():
        try:
            error, t = task_error_queue.get(block=False)
        except queue.Empty:
            # Nothing pending; nap briefly while staying responsive to the
            # stop signal.
            threads_stopped.wait(timeout=0.01)
            continue

        # Hold the error back for the grace period so that a redundant
        # message originating from the worker can be suppressed. Wait for
        # exactly the time that is left; Event.wait() returns True as soon as
        # the stop signal is set, which ends the delay early.
        deadline = t + UNCAUGHT_ERROR_GRACE_PERIOD
        while True:
            remaining = deadline - time.time()
            if remaining <= 0 or threads_stopped.wait(timeout=remaining):
                break

        # An error that is already in hand is still reported when stopping;
        # the outer loop condition then ends the thread.
        if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
            logger.debug(f"Suppressing error from worker: {error}")
        else:
            logger.error(f"Possible unhandled error from worker: {error}")
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
@@ -1034,6 +1063,8 @@ def listen_error_messages_raylet(worker, threads_stopped):
Args:
worker: The worker class that this thread belongs to.
task_error_queue (queue.Queue): A queue used to communicate with the
thread that prints the errors found by this thread.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
@@ -1072,9 +1103,8 @@ def listen_error_messages_raylet(worker, threads_stopped):
error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1237,12 +1267,19 @@ def connect(node,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
q = queue.Queue()
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, worker.threads_stopped))
args=(worker, q, worker.threads_stopped))
worker.printer_thread = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(q, worker.threads_stopped))
worker.listener_thread.daemon = True
worker.listener_thread.start()
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
@@ -1295,6 +1332,8 @@ def disconnect(exiting_interpreter=False):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
worker.listener_thread.join()
if hasattr(worker, "printer_thread"):
worker.printer_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()
worker.threads_stopped.clear()
@@ -1406,11 +1445,13 @@ def get(object_refs, *, timeout=None):
raise ValueError("'object_refs' must either be an object ref "
"or a list of object refs.")
global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):