Revert "Revert "Unhandled exception handler based on local ref counti… (#14113)

* Revert "Revert "Unhandled exception handler based on local ref counting (#14049)" (#14099)"

This reverts commit b45ae76765.

* remove test

* fix

* fix
This commit is contained in:
Eric Liang
2021-02-15 14:11:11 -08:00
committed by GitHub
parent 4846a6c2d0
commit e457872fe1
11 changed files with 210 additions and 68 deletions
+10
View File
@@ -702,6 +702,16 @@ cc_test(
],
)
# C++ test target for the core worker memory store. It is built from
# src/ray/core_worker/test/memory_store_test.cc and links :core_worker_lib
# together with googletest's gtest_main.
cc_test(
name = "memory_store_test",
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
copts = COPTS,
deps = [
":core_worker_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "direct_actor_transport_test",
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
+22 -3
View File
@@ -724,6 +724,20 @@ cdef void delete_spilled_objects_handler(
job_id=None)
# Callback handed to the core worker (via
# `options.unhandled_exception_handler`). It converts the C++ error object
# into Python-level (data, metadata) buffers and forwards them to
# `Worker.raise_errors` on the global worker.
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
    # Declared nogil, but everything below manipulates Python objects, so the
    # GIL is acquired for the whole body.
    with gil:
        global_worker = ray.worker.global_worker

        # Either part may be absent on the C++ object; None stands in for it.
        serialized_data = None
        serialized_metadata = None
        if error.HasData():
            serialized_data = Buffer.make(error.GetData())
        if error.HasMetadata():
            serialized_metadata = Buffer.make(
                error.GetMetadata()).to_pybytes()

        # TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
        placeholder_refs = [None]
        global_worker.raise_errors(
            [(serialized_data, serialized_metadata)], placeholder_refs)
# This function introduces ~2-7us of overhead per call (i.e., it can be called
# up to hundreds of thousands of times per second).
cdef void get_py_stack(c_string* stack_out) nogil:
@@ -833,6 +847,7 @@ cdef class CoreWorker:
options.spill_objects = spill_objects_handler
options.restore_spilled_objects = restore_spilled_objects_handler
options.delete_spilled_objects = delete_spilled_objects_handler
options.unhandled_exception_handler = unhandled_exception_handler
options.get_lang_stack = get_py_stack
options.ref_counting_enabled = True
options.is_local_mode = local_mode
@@ -1443,9 +1458,13 @@ cdef class CoreWorker:
object_ref.native())
def remove_object_ref_reference(self, ObjectRef object_ref):
# Note: faster to not release GIL for short-running op.
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
object_ref.native())
cdef:
CObjectID c_object_id = object_ref.native()
# We need to release the gil since object destruction may call the
# unhandled exception handler.
with nogil:
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
c_object_id)
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
cdef:
+1
View File
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(void(
const c_vector[c_string]&,
CWorkerType) nogil) delete_spilled_objects
(void(const CRayObject&) nogil) unhandled_exception_handler
(void(c_string *stack_out) nogil) get_lang_stack
c_bool ref_counting_enabled
c_bool is_local_mode
+46
View File
@@ -20,6 +20,52 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
get_error_message, Semaphore)
def test_unhandled_errors(ray_start_regular):
    """Exercise the unhandled-error reporting hook in `ray.worker`.

    Three scenarios are covered, each for a remote function and an actor
    method that raise `ValueError`:

    1. Dropping the object refs without fetching them reports each error.
    2. Fetching the results with `ray.get` before dropping the refs does not
       report anything further.
    3. With RAY_IGNORE_UNHANDLED_ERRORS set, dropping an un-fetched ref does
       not report anything further.

    Args:
        ray_start_regular: pytest fixture requested by this test (defined
            outside this block).
    """
    @ray.remote
    def f():
        raise ValueError()

    @ray.remote
    class Actor:
        def f(self):
            raise ValueError()

    a = Actor.remote()
    num_exceptions = 0

    def interceptor(e):
        nonlocal num_exceptions
        num_exceptions += 1

    # Swap in the counting handler, remembering the real one so it can be put
    # back even when an assertion below fails. Without the restore, the
    # patched module attribute would leak into every test that runs later in
    # the same process.
    original_handler = ray.worker._unhandled_error_handler
    ray.worker._unhandled_error_handler = interceptor
    try:
        # Test we report unhandled exceptions.
        x1 = f.remote()
        x2 = a.f.remote()
        del x1
        del x2
        wait_for_condition(lambda: num_exceptions == 2)

        # Test we don't report handled exceptions.
        x1 = f.remote()
        x2 = a.f.remote()
        with pytest.raises(ray.exceptions.RayError) as err:  # noqa
            ray.get([x1, x2])
        del x1
        del x2
        # Give a (wrongly) reported error time to arrive before asserting.
        time.sleep(1)
        assert num_exceptions == 2, num_exceptions

        # Test suppression with env var works.
        try:
            os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
            x1 = f.remote()
            del x1
            time.sleep(1)
            assert num_exceptions == 2, num_exceptions
        finally:
            del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
    finally:
        ray.worker._unhandled_error_handler = original_handler
def test_failed_task(ray_start_regular, error_pubsub):
@ray.remote
def throw_exception_fct1():
+19 -60
View File
@@ -9,7 +9,6 @@ import json
import logging
import os
import redis
from six.moves import queue
import sys
import threading
import time
@@ -69,6 +68,12 @@ ERROR_KEY_PREFIX = b"Error:"
logger = logging.getLogger(__name__)
# Visible for testing.
def _unhandled_error_handler(e: Exception):
    """Log `e` at ERROR level as an unhandled error.

    The message also names the environment variable that suppresses these
    reports.

    Args:
        e: The error to report; it is formatted into the log message.
    """
    message = ("Unhandled error (suppress with "
               "RAY_IGNORE_UNHANDLED_ERRORS=1): {}").format(e)
    logger.error(message)
class Worker:
"""A class used to define the control flow of a worker process.
@@ -277,6 +282,14 @@ class Worker:
self.core_worker.put_serialized_object(
serialized_value, object_ref=object_ref))
def raise_errors(self, data_metadata_pairs, object_refs):
    """Deserialize error objects and pass each one to the unhandled-error hook.

    Deserialization always happens. The hook is skipped for every object
    when RAY_IGNORE_UNHANDLED_ERRORS is present in the environment,
    whatever its value.

    Args:
        data_metadata_pairs: Serialized objects, forwarded unchanged to the
            serialization context.
        object_refs: Refs matching `data_metadata_pairs`, forwarded
            unchanged to the serialization context.
    """
    serialization_context = self.get_serialization_context()
    errors = serialization_context.deserialize_objects(
        data_metadata_pairs, object_refs)

    reporting_enabled = "RAY_IGNORE_UNHANDLED_ERRORS" not in os.environ
    if reporting_enabled:
        for error in errors:
            _unhandled_error_handler(error)
def deserialize_objects(self, data_metadata_pairs, object_refs):
    """Deserialize objects through this worker's serialization context.

    Args:
        data_metadata_pairs: Serialized objects, forwarded unchanged.
        object_refs: Refs matching `data_metadata_pairs`, forwarded unchanged.

    Returns:
        Whatever the serialization context's `deserialize_objects` returns.
    """
    return self.get_serialization_context().deserialize_objects(
        data_metadata_pairs, object_refs)
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb):
sys.excepthook = custom_excepthook
# The last time we raised a TaskError in this process. We use this value to
# suppress redundant error messages pushed from the workers.
last_task_error_raise_time = 0
# The max amount of seconds to wait before printing out an uncaught error.
UNCAUGHT_ERROR_GRACE_PERIOD = 5
def print_logs(redis_client, threads_stopped, job_id):
"""Prints log messages from workers on all of the nodes.
@@ -1020,42 +1026,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
file=print_file)
def print_error_messages_raylet(task_error_queue, threads_stopped):
    """Drain queued worker errors and log them after a grace period.

    Each queue item is an `(error, timestamp)` pair. An item is held until
    UNCAUGHT_ERROR_GRACE_PERIOD seconds have passed since its timestamp (or
    until the stop event fires). It is then logged at DEBUG if a task error
    was raised in this process close enough to the timestamp, and at ERROR
    otherwise.

    Args:
        task_error_queue (queue.Queue): A queue used to receive errors from the
            thread that listens to Redis.
        threads_stopped (threading.Event): A threading event used to signal to
            the thread that it should exit.
    """
    # Keep polling until we are told to stop.
    while not threads_stopped.is_set():
        try:
            message, pushed_at = task_error_queue.get(block=False)
        except queue.Empty:
            # Nothing pending: back off briefly, then re-check the stop flag.
            threads_stopped.wait(timeout=0.01)
            continue

        # Hold the message back for the grace period so that a matching
        # locally raised error has a chance to be recorded first.
        while time.time() < pushed_at + UNCAUGHT_ERROR_GRACE_PERIOD:
            threads_stopped.wait(timeout=1)
            if threads_stopped.is_set():
                break

        raised_recently = (
            pushed_at <
            last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD)
        if raised_recently:
            logger.debug(f"Suppressing error from worker: {message}")
        else:
            logger.error(f"Possible unhandled error from worker: {message}")
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
def listen_error_messages_raylet(worker, threads_stopped):
"""Listen to error messages in the background on the driver.
This runs in a separate thread on the driver and pushes (error, time)
@@ -1063,8 +1034,6 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
Args:
worker: The worker class that this thread belongs to.
task_error_queue (queue.Queue): A queue used to communicate with the
thread that prints the errors found by this thread.
threads_stopped (threading.Event): A threading event used to signal to
the thread that it should exit.
"""
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
error_message = error_data.error_message
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
# Delay it a bit to see if we can suppress it
task_error_queue.put((error_message, time.time()))
# TODO(ekl) remove task push errors entirely now that we have
# the separate unhandled exception handler.
pass
else:
logger.warning(error_message)
except (OSError, redis.exceptions.ConnectionError) as e:
@@ -1267,19 +1237,12 @@ def connect(node,
# temporarily using this implementation which constantly queries the
# scheduler for new error messages.
if mode == SCRIPT_MODE:
q = queue.Queue()
worker.listener_thread = threading.Thread(
target=listen_error_messages_raylet,
name="ray_listen_error_messages",
args=(worker, q, worker.threads_stopped))
worker.printer_thread = threading.Thread(
target=print_error_messages_raylet,
name="ray_print_error_messages",
args=(q, worker.threads_stopped))
args=(worker, worker.threads_stopped))
worker.listener_thread.daemon = True
worker.listener_thread.start()
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False):
worker.import_thread.join_import_thread()
if hasattr(worker, "listener_thread"):
worker.listener_thread.join()
if hasattr(worker, "printer_thread"):
worker.printer_thread.join()
if hasattr(worker, "logger_thread"):
worker.logger_thread.join()
worker.threads_stopped.clear()
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None):
raise ValueError("'object_refs' must either be an object ref "
"or a list of object refs.")
global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
if isinstance(value, ray.exceptions.ObjectLostError):
worker.core_worker.dump_object_store_memory_usage()
if isinstance(value, RayTaskError):
+8
View File
@@ -92,12 +92,20 @@ class RayObject {
/// large to return directly as part of a gRPC response).
bool IsInPlasmaError() const;
/// Mark this object as accessed before. After this call WasAccessed()
/// returns true.
void SetAccessed() { accessed_ = true; }
/// Check if this object was accessed before.
///
/// \return The current value of the accessed flag, which SetAccessed() sets
/// to true.
bool WasAccessed() const { return accessed_; }
private:
std::shared_ptr<Buffer> data_;
std::shared_ptr<Buffer> metadata_;
const std::vector<ObjectID> nested_ids_;
/// Whether this class holds a data copy.
bool has_data_copy_;
/// Whether this object was accessed.
bool accessed_ = false;
};
} // namespace ray
+1 -1
View File
@@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
return Status::OK();
},
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
options_.check_signals));
options_.check_signals, options_.unhandled_exception_handler));
auto check_node_alive_fn = [this](const NodeID &node_id) {
auto node = gcs_client_->Nodes().Get(node_id);
+3
View File
@@ -82,6 +82,7 @@ struct CoreWorkerOptions {
spill_objects(nullptr),
restore_spilled_objects(nullptr),
delete_spilled_objects(nullptr),
unhandled_exception_handler(nullptr),
get_lang_stack(nullptr),
kill_main(nullptr),
ref_counting_enabled(false),
@@ -146,6 +147,8 @@ struct CoreWorkerOptions {
/// Application-language callback to delete objects from external storage.
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
delete_spilled_objects;
/// Function to call on error objects never retrieved.
std::function<void(const RayObject &error)> unhandled_exception_handler;
/// Language worker callback to get the current call stack.
std::function<void(std::string *)> get_lang_stack;
// Function that tries to interrupt the currently running Python thread.
@@ -93,6 +93,7 @@ void GetRequest::Set(const ObjectID &object_id, std::shared_ptr<RayObject> objec
if (is_ready_) {
return; // We have already hit the number of objects to return limit.
}
object->SetAccessed();
objects_.emplace(object_id, object);
if (objects_.size() == num_objects_ ||
(abort_if_any_object_is_exception_ && object->IsException() &&
@@ -106,6 +107,7 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
std::unique_lock<std::mutex> lock(mutex_);
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
iter->second->SetAccessed();
return iter->second;
}
@@ -116,11 +118,13 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore(
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
std::shared_ptr<ReferenceCounter> counter,
std::shared_ptr<raylet::RayletClient> raylet_client,
std::function<Status()> check_signals)
std::function<Status()> check_signals,
std::function<void(const RayObject &)> unhandled_exception_handler)
: store_in_plasma_(store_in_plasma),
ref_counter_(counter),
raylet_client_(raylet_client),
check_signals_(check_signals) {}
check_signals_(check_signals),
unhandled_exception_handler_(unhandled_exception_handler) {}
void CoreWorkerMemoryStore::GetAsync(
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
@@ -136,6 +140,7 @@ void CoreWorkerMemoryStore::GetAsync(
}
// It's important for performance to run the callback outside the lock.
if (ptr != nullptr) {
ptr->SetAccessed();
callback(ptr);
}
}
@@ -146,6 +151,7 @@ std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
auto obj = iter->second;
obj->SetAccessed();
if (obj->IsInPlasmaError()) {
return nullptr;
}
@@ -210,6 +216,8 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
if (should_add_entry) {
// If there is no existing get request, then add the `RayObject` to map.
objects_.emplace(object_id, object_entry);
} else {
OnErase(object_entry);
}
}
@@ -223,6 +231,7 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
// It's important for performance to run the callbacks outside the lock.
for (const auto &cb : async_callbacks) {
object_entry->SetAccessed();
cb(object_entry);
}
@@ -257,6 +266,7 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
const auto &object_id = object_ids[i];
auto iter = objects_.find(object_id);
if (iter != objects_.end()) {
iter->second->SetAccessed();
(*results)[i] = iter->second;
if (remove_after_get) {
// Note that we cannot remove the object_id from `objects_` now,
@@ -426,6 +436,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
if (it->second->IsInPlasmaError()) {
plasma_ids_to_delete->insert(object_id);
} else {
OnErase(it->second);
objects_.erase(it);
}
}
@@ -435,7 +446,11 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
absl::MutexLock lock(&mu_);
for (const auto &object_id : object_ids) {
objects_.erase(object_id);
auto it = objects_.find(object_id);
if (it != objects_.end()) {
OnErase(it->second);
objects_.erase(it);
}
}
}
@@ -451,6 +466,14 @@ bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id, bool *in_plasma)
return false;
}
void CoreWorkerMemoryStore::OnErase(std::shared_ptr<RayObject> obj) {
  // Report the object through the unhandled-exception callback only when it
  // is an exception, is not the in-plasma error marker, and was never
  // accessed.
  // TODO(ekl) note that this doesn't warn on errors that are stored in plasma.
  const bool is_unseen_error =
      obj->IsException() && !obj->IsInPlasmaError() && !obj->WasAccessed();
  if (!is_unseen_error) {
    return;
  }
  // The callback is optional; with none installed there is nothing to do.
  if (unhandled_exception_handler_ == nullptr) {
    return;
  }
  unhandled_exception_handler_(*obj);
}
MemoryStoreStats CoreWorkerMemoryStore::GetMemoryStoreStatisticalData() {
absl::MutexLock lock(&mu_);
MemoryStoreStats item;
@@ -35,7 +35,8 @@ class CoreWorkerMemoryStore {
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
std::shared_ptr<ReferenceCounter> counter = nullptr,
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
std::function<Status()> check_signals = nullptr);
std::function<Status()> check_signals = nullptr,
std::function<void(const RayObject &)> unhandled_exception_handler = nullptr);
~CoreWorkerMemoryStore(){};
/// Put an object with specified ID into object store.
@@ -143,6 +144,9 @@ class CoreWorkerMemoryStore {
std::vector<std::shared_ptr<RayObject>> *results,
bool abort_if_any_object_is_exception);
/// Called when an object is erased from the store.
void OnErase(std::shared_ptr<RayObject> obj);
/// Optional callback for putting objects into the plasma store.
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
@@ -173,6 +177,9 @@ class CoreWorkerMemoryStore {
/// Function passed in to be called to check for signals (e.g., Ctrl-C).
std::function<Status()> check_signals_;
/// Function called to report unhandled exceptions.
std::function<void(const RayObject &)> unhandled_exception_handler_;
};
} // namespace ray
@@ -0,0 +1,66 @@
// Copyright 2017 The Ray Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
#include "gtest/gtest.h"
#include "ray/common/test_util.h"
namespace ray {
// Verifies when the memory store invokes its unhandled-exception callback:
// deleting an error object that was never fetched must trigger it once per
// object, while deleting after a Get / GetOrPromoteToPlasma / GetAsync must
// not trigger it.
TEST(TestMemoryStore, TestReportUnhandledErrors) {
  std::vector<std::shared_ptr<RayObject>> results;
  WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
  int unhandled_count = 0;

  // Only the unhandled-exception callback is supplied; it counts invocations.
  std::shared_ptr<CoreWorkerMemoryStore> provider =
      std::make_shared<CoreWorkerMemoryStore>(
          nullptr, nullptr, nullptr, nullptr,
          [&](const RayObject &obj) { unhandled_count++; });
  RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
  RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
  auto id1 = ObjectID::FromRandom();
  auto id2 = ObjectID::FromRandom();

  // Check delete without get.
  // ASSERT_TRUE (rather than RAY_CHECK) makes a failed Put show up as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(provider->Put(obj1, id1));
  ASSERT_TRUE(provider->Put(obj2, id2));
  ASSERT_EQ(unhandled_count, 0);
  provider->Delete({id1, id2});
  ASSERT_EQ(unhandled_count, 2);
  unhandled_count = 0;

  // Check delete after get.
  ASSERT_TRUE(provider->Put(obj1, id1));
  ASSERT_TRUE(provider->Put(obj1, id2));
  provider->Get({id1}, 1, 100, context, false, &results);
  provider->GetOrPromoteToPlasma(id2);
  provider->Delete({id1, id2});
  ASSERT_EQ(unhandled_count, 0);

  // Check delete after async get.
  // id2 is requested before it is put; id1 is requested after it is put.
  provider->GetAsync({id2}, [](std::shared_ptr<RayObject> obj) {});
  ASSERT_TRUE(provider->Put(obj1, id1));
  ASSERT_TRUE(provider->Put(obj2, id2));
  provider->GetAsync({id1}, [](std::shared_ptr<RayObject> obj) {});
  provider->Delete({id1, id2});
  ASSERT_EQ(unhandled_count, 0);
}
} // namespace ray
// Test binary entry point: hands the command-line arguments to GoogleTest and
// returns the result of RUN_ALL_TESTS() as the process exit code.
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}