mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
This reverts commit 9dc671ae02.
This commit is contained in:
+3
-22
@@ -724,20 +724,6 @@ cdef void delete_spilled_objects_handler(
|
||||
job_id=None)
|
||||
|
||||
|
||||
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
|
||||
with gil:
|
||||
worker = ray.worker.global_worker
|
||||
data = None
|
||||
metadata = None
|
||||
if error.HasData():
|
||||
data = Buffer.make(error.GetData())
|
||||
if error.HasMetadata():
|
||||
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
|
||||
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
|
||||
object_ids = [None]
|
||||
worker.raise_errors([(data, metadata)], object_ids)
|
||||
|
||||
|
||||
# This function introduces ~2-7us of overhead per call (i.e., it can be called
|
||||
# up to hundreds of thousands of times per second).
|
||||
cdef void get_py_stack(c_string* stack_out) nogil:
|
||||
@@ -847,7 +833,6 @@ cdef class CoreWorker:
|
||||
options.spill_objects = spill_objects_handler
|
||||
options.restore_spilled_objects = restore_spilled_objects_handler
|
||||
options.delete_spilled_objects = delete_spilled_objects_handler
|
||||
options.unhandled_exception_handler = unhandled_exception_handler
|
||||
options.get_lang_stack = get_py_stack
|
||||
options.ref_counting_enabled = True
|
||||
options.is_local_mode = local_mode
|
||||
@@ -1458,13 +1443,9 @@ cdef class CoreWorker:
|
||||
object_ref.native())
|
||||
|
||||
def remove_object_ref_reference(self, ObjectRef object_ref):
|
||||
cdef:
|
||||
CObjectID c_object_id = object_ref.native()
|
||||
# We need to release the gil since object destruction may call the
|
||||
# unhandled exception handler.
|
||||
with nogil:
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||
c_object_id)
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||
object_ref.native())
|
||||
|
||||
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
|
||||
cdef:
|
||||
|
||||
@@ -250,7 +250,6 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
(void(
|
||||
const c_vector[c_string]&,
|
||||
CWorkerType) nogil) delete_spilled_objects
|
||||
(void(const CRayObject&) nogil) unhandled_exception_handler
|
||||
(void(c_string *stack_out) nogil) get_lang_stack
|
||||
c_bool ref_counting_enabled
|
||||
c_bool is_local_mode
|
||||
|
||||
@@ -20,52 +20,6 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
|
||||
get_error_message, Semaphore)
|
||||
|
||||
|
||||
def test_unhandled_errors(ray_start_regular):
|
||||
@ray.remote
|
||||
def f():
|
||||
raise ValueError()
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def f(self):
|
||||
raise ValueError()
|
||||
|
||||
a = Actor.remote()
|
||||
num_exceptions = 0
|
||||
|
||||
def interceptor(e):
|
||||
nonlocal num_exceptions
|
||||
num_exceptions += 1
|
||||
|
||||
# Test we report unhandled exceptions.
|
||||
ray.worker._unhandled_error_handler = interceptor
|
||||
x1 = f.remote()
|
||||
x2 = a.f.remote()
|
||||
del x1
|
||||
del x2
|
||||
wait_for_condition(lambda: num_exceptions == 2)
|
||||
|
||||
# Test we don't report handled exceptions.
|
||||
x1 = f.remote()
|
||||
x2 = a.f.remote()
|
||||
with pytest.raises(ray.exceptions.RayError) as err: # noqa
|
||||
ray.get([x1, x2])
|
||||
del x1
|
||||
del x2
|
||||
time.sleep(1)
|
||||
assert num_exceptions == 2, num_exceptions
|
||||
|
||||
# Test suppression with env var works.
|
||||
try:
|
||||
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
|
||||
x1 = f.remote()
|
||||
del x1
|
||||
time.sleep(1)
|
||||
assert num_exceptions == 2, num_exceptions
|
||||
finally:
|
||||
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
|
||||
|
||||
|
||||
def test_failed_task(ray_start_regular, error_pubsub):
|
||||
@ray.remote
|
||||
def throw_exception_fct1():
|
||||
|
||||
+60
-19
@@ -9,6 +9,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import redis
|
||||
from six.moves import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -68,12 +69,6 @@ ERROR_KEY_PREFIX = b"Error:"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Visible for testing.
|
||||
def _unhandled_error_handler(e: Exception):
|
||||
logger.error("Unhandled error (suppress with "
|
||||
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
|
||||
|
||||
|
||||
class Worker:
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -282,14 +277,6 @@ class Worker:
|
||||
self.core_worker.put_serialized_object(
|
||||
serialized_value, object_ref=object_ref))
|
||||
|
||||
def raise_errors(self, data_metadata_pairs, object_refs):
|
||||
context = self.get_serialization_context()
|
||||
out = context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
|
||||
return
|
||||
for e in out:
|
||||
_unhandled_error_handler(e)
|
||||
|
||||
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
||||
context = self.get_serialization_context()
|
||||
return context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||
@@ -876,6 +863,13 @@ def custom_excepthook(type, value, tb):
|
||||
|
||||
sys.excepthook = custom_excepthook
|
||||
|
||||
# The last time we raised a TaskError in this process. We use this value to
|
||||
# suppress redundant error messages pushed from the workers.
|
||||
last_task_error_raise_time = 0
|
||||
|
||||
# The max amount of seconds to wait before printing out an uncaught error.
|
||||
UNCAUGHT_ERROR_GRACE_PERIOD = 5
|
||||
|
||||
|
||||
def print_logs(redis_client, threads_stopped, job_id):
|
||||
"""Prints log messages from workers on all of the nodes.
|
||||
@@ -1026,7 +1020,42 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
|
||||
file=print_file)
|
||||
|
||||
|
||||
def listen_error_messages_raylet(worker, threads_stopped):
|
||||
def print_error_messages_raylet(task_error_queue, threads_stopped):
|
||||
"""Prints message received in the given output queue.
|
||||
|
||||
This checks periodically if any un-raised errors occurred in the
|
||||
background.
|
||||
|
||||
Args:
|
||||
task_error_queue (queue.Queue): A queue used to receive errors from the
|
||||
thread that listens to Redis.
|
||||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
|
||||
while True:
|
||||
# Exit if we received a signal that we should stop.
|
||||
if threads_stopped.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
error, t = task_error_queue.get(block=False)
|
||||
except queue.Empty:
|
||||
threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
# Delay errors a little bit of time to attempt to suppress redundant
|
||||
# messages originating from the worker.
|
||||
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
|
||||
threads_stopped.wait(timeout=1)
|
||||
if threads_stopped.is_set():
|
||||
break
|
||||
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
|
||||
logger.debug(f"Suppressing error from worker: {error}")
|
||||
else:
|
||||
logger.error(f"Possible unhandled error from worker: {error}")
|
||||
|
||||
|
||||
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
"""Listen to error messages in the background on the driver.
|
||||
|
||||
This runs in a separate thread on the driver and pushes (error, time)
|
||||
@@ -1034,6 +1063,8 @@ def listen_error_messages_raylet(worker, threads_stopped):
|
||||
|
||||
Args:
|
||||
worker: The worker class that this thread belongs to.
|
||||
task_error_queue (queue.Queue): A queue used to communicate with the
|
||||
thread that prints the errors found by this thread.
|
||||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
@@ -1072,9 +1103,8 @@ def listen_error_messages_raylet(worker, threads_stopped):
|
||||
|
||||
error_message = error_data.error_message
|
||||
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
|
||||
# TODO(ekl) remove task push errors entirely now that we have
|
||||
# the separate unhandled exception handler.
|
||||
pass
|
||||
# Delay it a bit to see if we can suppress it
|
||||
task_error_queue.put((error_message, time.time()))
|
||||
else:
|
||||
logger.warning(error_message)
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
@@ -1237,12 +1267,19 @@ def connect(node,
|
||||
# temporarily using this implementation which constantly queries the
|
||||
# scheduler for new error messages.
|
||||
if mode == SCRIPT_MODE:
|
||||
q = queue.Queue()
|
||||
worker.listener_thread = threading.Thread(
|
||||
target=listen_error_messages_raylet,
|
||||
name="ray_listen_error_messages",
|
||||
args=(worker, worker.threads_stopped))
|
||||
args=(worker, q, worker.threads_stopped))
|
||||
worker.printer_thread = threading.Thread(
|
||||
target=print_error_messages_raylet,
|
||||
name="ray_print_error_messages",
|
||||
args=(q, worker.threads_stopped))
|
||||
worker.listener_thread.daemon = True
|
||||
worker.listener_thread.start()
|
||||
worker.printer_thread.daemon = True
|
||||
worker.printer_thread.start()
|
||||
if log_to_driver:
|
||||
global_worker_stdstream_dispatcher.add_handler(
|
||||
"ray_print_logs", print_to_stdstream)
|
||||
@@ -1295,6 +1332,8 @@ def disconnect(exiting_interpreter=False):
|
||||
worker.import_thread.join_import_thread()
|
||||
if hasattr(worker, "listener_thread"):
|
||||
worker.listener_thread.join()
|
||||
if hasattr(worker, "printer_thread"):
|
||||
worker.printer_thread.join()
|
||||
if hasattr(worker, "logger_thread"):
|
||||
worker.logger_thread.join()
|
||||
worker.threads_stopped.clear()
|
||||
@@ -1406,11 +1445,13 @@ def get(object_refs, *, timeout=None):
|
||||
raise ValueError("'object_refs' must either be an object ref "
|
||||
"or a list of object refs.")
|
||||
|
||||
global last_task_error_raise_time
|
||||
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
||||
values, debugger_breakpoint = worker.get_objects(
|
||||
object_refs, timeout=timeout)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
last_task_error_raise_time = time.time()
|
||||
if isinstance(value, ray.exceptions.ObjectLostError):
|
||||
worker.core_worker.dump_object_store_memory_usage()
|
||||
if isinstance(value, RayTaskError):
|
||||
|
||||
Reference in New Issue
Block a user