mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:46:37 +08:00
Revert "Revert "Unhandled exception handler based on local ref counti… (#14113)
* Revert "Revert "Unhandled exception handler based on local ref counting (#14049)" (#14099)"
This reverts commit b45ae76765.
* remove test
* fix
* fix
This commit is contained in:
+10
@@ -702,6 +702,16 @@ cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_test(
|
||||||
|
name = "memory_store_test",
|
||||||
|
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
|
||||||
|
copts = COPTS,
|
||||||
|
deps = [
|
||||||
|
":core_worker_lib",
|
||||||
|
"@com_google_googletest//:gtest_main",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "direct_actor_transport_test",
|
name = "direct_actor_transport_test",
|
||||||
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
|
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
|
||||||
|
|||||||
+22
-3
@@ -724,6 +724,20 @@ cdef void delete_spilled_objects_handler(
|
|||||||
job_id=None)
|
job_id=None)
|
||||||
|
|
||||||
|
|
||||||
|
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
|
||||||
|
with gil:
|
||||||
|
worker = ray.worker.global_worker
|
||||||
|
data = None
|
||||||
|
metadata = None
|
||||||
|
if error.HasData():
|
||||||
|
data = Buffer.make(error.GetData())
|
||||||
|
if error.HasMetadata():
|
||||||
|
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
|
||||||
|
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
|
||||||
|
object_ids = [None]
|
||||||
|
worker.raise_errors([(data, metadata)], object_ids)
|
||||||
|
|
||||||
|
|
||||||
# This function introduces ~2-7us of overhead per call (i.e., it can be called
|
# This function introduces ~2-7us of overhead per call (i.e., it can be called
|
||||||
# up to hundreds of thousands of times per second).
|
# up to hundreds of thousands of times per second).
|
||||||
cdef void get_py_stack(c_string* stack_out) nogil:
|
cdef void get_py_stack(c_string* stack_out) nogil:
|
||||||
@@ -833,6 +847,7 @@ cdef class CoreWorker:
|
|||||||
options.spill_objects = spill_objects_handler
|
options.spill_objects = spill_objects_handler
|
||||||
options.restore_spilled_objects = restore_spilled_objects_handler
|
options.restore_spilled_objects = restore_spilled_objects_handler
|
||||||
options.delete_spilled_objects = delete_spilled_objects_handler
|
options.delete_spilled_objects = delete_spilled_objects_handler
|
||||||
|
options.unhandled_exception_handler = unhandled_exception_handler
|
||||||
options.get_lang_stack = get_py_stack
|
options.get_lang_stack = get_py_stack
|
||||||
options.ref_counting_enabled = True
|
options.ref_counting_enabled = True
|
||||||
options.is_local_mode = local_mode
|
options.is_local_mode = local_mode
|
||||||
@@ -1443,9 +1458,13 @@ cdef class CoreWorker:
|
|||||||
object_ref.native())
|
object_ref.native())
|
||||||
|
|
||||||
def remove_object_ref_reference(self, ObjectRef object_ref):
|
def remove_object_ref_reference(self, ObjectRef object_ref):
|
||||||
# Note: faster to not release GIL for short-running op.
|
cdef:
|
||||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
CObjectID c_object_id = object_ref.native()
|
||||||
object_ref.native())
|
# We need to release the gil since object destruction may call the
|
||||||
|
# unhandled exception handler.
|
||||||
|
with nogil:
|
||||||
|
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||||
|
c_object_id)
|
||||||
|
|
||||||
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
|
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
|
||||||
cdef:
|
cdef:
|
||||||
|
|||||||
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
|||||||
(void(
|
(void(
|
||||||
const c_vector[c_string]&,
|
const c_vector[c_string]&,
|
||||||
CWorkerType) nogil) delete_spilled_objects
|
CWorkerType) nogil) delete_spilled_objects
|
||||||
|
(void(const CRayObject&) nogil) unhandled_exception_handler
|
||||||
(void(c_string *stack_out) nogil) get_lang_stack
|
(void(c_string *stack_out) nogil) get_lang_stack
|
||||||
c_bool ref_counting_enabled
|
c_bool ref_counting_enabled
|
||||||
c_bool is_local_mode
|
c_bool is_local_mode
|
||||||
|
|||||||
@@ -20,6 +20,52 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
|
|||||||
get_error_message, Semaphore)
|
get_error_message, Semaphore)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unhandled_errors(ray_start_regular):
|
||||||
|
@ray.remote
|
||||||
|
def f():
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class Actor:
|
||||||
|
def f(self):
|
||||||
|
raise ValueError()
|
||||||
|
|
||||||
|
a = Actor.remote()
|
||||||
|
num_exceptions = 0
|
||||||
|
|
||||||
|
def interceptor(e):
|
||||||
|
nonlocal num_exceptions
|
||||||
|
num_exceptions += 1
|
||||||
|
|
||||||
|
# Test we report unhandled exceptions.
|
||||||
|
ray.worker._unhandled_error_handler = interceptor
|
||||||
|
x1 = f.remote()
|
||||||
|
x2 = a.f.remote()
|
||||||
|
del x1
|
||||||
|
del x2
|
||||||
|
wait_for_condition(lambda: num_exceptions == 2)
|
||||||
|
|
||||||
|
# Test we don't report handled exceptions.
|
||||||
|
x1 = f.remote()
|
||||||
|
x2 = a.f.remote()
|
||||||
|
with pytest.raises(ray.exceptions.RayError) as err: # noqa
|
||||||
|
ray.get([x1, x2])
|
||||||
|
del x1
|
||||||
|
del x2
|
||||||
|
time.sleep(1)
|
||||||
|
assert num_exceptions == 2, num_exceptions
|
||||||
|
|
||||||
|
# Test suppression with env var works.
|
||||||
|
try:
|
||||||
|
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
|
||||||
|
x1 = f.remote()
|
||||||
|
del x1
|
||||||
|
time.sleep(1)
|
||||||
|
assert num_exceptions == 2, num_exceptions
|
||||||
|
finally:
|
||||||
|
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
|
||||||
|
|
||||||
|
|
||||||
def test_failed_task(ray_start_regular, error_pubsub):
|
def test_failed_task(ray_start_regular, error_pubsub):
|
||||||
@ray.remote
|
@ray.remote
|
||||||
def throw_exception_fct1():
|
def throw_exception_fct1():
|
||||||
|
|||||||
+19
-60
@@ -9,7 +9,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import redis
|
import redis
|
||||||
from six.moves import queue
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -69,6 +68,12 @@ ERROR_KEY_PREFIX = b"Error:"
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Visible for testing.
|
||||||
|
def _unhandled_error_handler(e: Exception):
|
||||||
|
logger.error("Unhandled error (suppress with "
|
||||||
|
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
|
||||||
|
|
||||||
|
|
||||||
class Worker:
|
class Worker:
|
||||||
"""A class used to define the control flow of a worker process.
|
"""A class used to define the control flow of a worker process.
|
||||||
|
|
||||||
@@ -277,6 +282,14 @@ class Worker:
|
|||||||
self.core_worker.put_serialized_object(
|
self.core_worker.put_serialized_object(
|
||||||
serialized_value, object_ref=object_ref))
|
serialized_value, object_ref=object_ref))
|
||||||
|
|
||||||
|
def raise_errors(self, data_metadata_pairs, object_refs):
|
||||||
|
context = self.get_serialization_context()
|
||||||
|
out = context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||||
|
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
|
||||||
|
return
|
||||||
|
for e in out:
|
||||||
|
_unhandled_error_handler(e)
|
||||||
|
|
||||||
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
||||||
context = self.get_serialization_context()
|
context = self.get_serialization_context()
|
||||||
return context.deserialize_objects(data_metadata_pairs, object_refs)
|
return context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||||
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb):
|
|||||||
|
|
||||||
sys.excepthook = custom_excepthook
|
sys.excepthook = custom_excepthook
|
||||||
|
|
||||||
# The last time we raised a TaskError in this process. We use this value to
|
|
||||||
# suppress redundant error messages pushed from the workers.
|
|
||||||
last_task_error_raise_time = 0
|
|
||||||
|
|
||||||
# The max amount of seconds to wait before printing out an uncaught error.
|
|
||||||
UNCAUGHT_ERROR_GRACE_PERIOD = 5
|
|
||||||
|
|
||||||
|
|
||||||
def print_logs(redis_client, threads_stopped, job_id):
|
def print_logs(redis_client, threads_stopped, job_id):
|
||||||
"""Prints log messages from workers on all of the nodes.
|
"""Prints log messages from workers on all of the nodes.
|
||||||
@@ -1020,42 +1026,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
|
|||||||
file=print_file)
|
file=print_file)
|
||||||
|
|
||||||
|
|
||||||
def print_error_messages_raylet(task_error_queue, threads_stopped):
|
def listen_error_messages_raylet(worker, threads_stopped):
|
||||||
"""Prints message received in the given output queue.
|
|
||||||
|
|
||||||
This checks periodically if any un-raised errors occurred in the
|
|
||||||
background.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_error_queue (queue.Queue): A queue used to receive errors from the
|
|
||||||
thread that listens to Redis.
|
|
||||||
threads_stopped (threading.Event): A threading event used to signal to
|
|
||||||
the thread that it should exit.
|
|
||||||
"""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Exit if we received a signal that we should stop.
|
|
||||||
if threads_stopped.is_set():
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
error, t = task_error_queue.get(block=False)
|
|
||||||
except queue.Empty:
|
|
||||||
threads_stopped.wait(timeout=0.01)
|
|
||||||
continue
|
|
||||||
# Delay errors a little bit of time to attempt to suppress redundant
|
|
||||||
# messages originating from the worker.
|
|
||||||
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
|
|
||||||
threads_stopped.wait(timeout=1)
|
|
||||||
if threads_stopped.is_set():
|
|
||||||
break
|
|
||||||
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
|
|
||||||
logger.debug(f"Suppressing error from worker: {error}")
|
|
||||||
else:
|
|
||||||
logger.error(f"Possible unhandled error from worker: {error}")
|
|
||||||
|
|
||||||
|
|
||||||
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
|
||||||
"""Listen to error messages in the background on the driver.
|
"""Listen to error messages in the background on the driver.
|
||||||
|
|
||||||
This runs in a separate thread on the driver and pushes (error, time)
|
This runs in a separate thread on the driver and pushes (error, time)
|
||||||
@@ -1063,8 +1034,6 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
worker: The worker class that this thread belongs to.
|
worker: The worker class that this thread belongs to.
|
||||||
task_error_queue (queue.Queue): A queue used to communicate with the
|
|
||||||
thread that prints the errors found by this thread.
|
|
||||||
threads_stopped (threading.Event): A threading event used to signal to
|
threads_stopped (threading.Event): A threading event used to signal to
|
||||||
the thread that it should exit.
|
the thread that it should exit.
|
||||||
"""
|
"""
|
||||||
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
|||||||
|
|
||||||
error_message = error_data.error_message
|
error_message = error_data.error_message
|
||||||
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
|
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
|
||||||
# Delay it a bit to see if we can suppress it
|
# TODO(ekl) remove task push errors entirely now that we have
|
||||||
task_error_queue.put((error_message, time.time()))
|
# the separate unhandled exception handler.
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
logger.warning(error_message)
|
logger.warning(error_message)
|
||||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||||
@@ -1267,19 +1237,12 @@ def connect(node,
|
|||||||
# temporarily using this implementation which constantly queries the
|
# temporarily using this implementation which constantly queries the
|
||||||
# scheduler for new error messages.
|
# scheduler for new error messages.
|
||||||
if mode == SCRIPT_MODE:
|
if mode == SCRIPT_MODE:
|
||||||
q = queue.Queue()
|
|
||||||
worker.listener_thread = threading.Thread(
|
worker.listener_thread = threading.Thread(
|
||||||
target=listen_error_messages_raylet,
|
target=listen_error_messages_raylet,
|
||||||
name="ray_listen_error_messages",
|
name="ray_listen_error_messages",
|
||||||
args=(worker, q, worker.threads_stopped))
|
args=(worker, worker.threads_stopped))
|
||||||
worker.printer_thread = threading.Thread(
|
|
||||||
target=print_error_messages_raylet,
|
|
||||||
name="ray_print_error_messages",
|
|
||||||
args=(q, worker.threads_stopped))
|
|
||||||
worker.listener_thread.daemon = True
|
worker.listener_thread.daemon = True
|
||||||
worker.listener_thread.start()
|
worker.listener_thread.start()
|
||||||
worker.printer_thread.daemon = True
|
|
||||||
worker.printer_thread.start()
|
|
||||||
if log_to_driver:
|
if log_to_driver:
|
||||||
global_worker_stdstream_dispatcher.add_handler(
|
global_worker_stdstream_dispatcher.add_handler(
|
||||||
"ray_print_logs", print_to_stdstream)
|
"ray_print_logs", print_to_stdstream)
|
||||||
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False):
|
|||||||
worker.import_thread.join_import_thread()
|
worker.import_thread.join_import_thread()
|
||||||
if hasattr(worker, "listener_thread"):
|
if hasattr(worker, "listener_thread"):
|
||||||
worker.listener_thread.join()
|
worker.listener_thread.join()
|
||||||
if hasattr(worker, "printer_thread"):
|
|
||||||
worker.printer_thread.join()
|
|
||||||
if hasattr(worker, "logger_thread"):
|
if hasattr(worker, "logger_thread"):
|
||||||
worker.logger_thread.join()
|
worker.logger_thread.join()
|
||||||
worker.threads_stopped.clear()
|
worker.threads_stopped.clear()
|
||||||
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None):
|
|||||||
raise ValueError("'object_refs' must either be an object ref "
|
raise ValueError("'object_refs' must either be an object ref "
|
||||||
"or a list of object refs.")
|
"or a list of object refs.")
|
||||||
|
|
||||||
global last_task_error_raise_time
|
|
||||||
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
||||||
values, debugger_breakpoint = worker.get_objects(
|
values, debugger_breakpoint = worker.get_objects(
|
||||||
object_refs, timeout=timeout)
|
object_refs, timeout=timeout)
|
||||||
for i, value in enumerate(values):
|
for i, value in enumerate(values):
|
||||||
if isinstance(value, RayError):
|
if isinstance(value, RayError):
|
||||||
last_task_error_raise_time = time.time()
|
|
||||||
if isinstance(value, ray.exceptions.ObjectLostError):
|
if isinstance(value, ray.exceptions.ObjectLostError):
|
||||||
worker.core_worker.dump_object_store_memory_usage()
|
worker.core_worker.dump_object_store_memory_usage()
|
||||||
if isinstance(value, RayTaskError):
|
if isinstance(value, RayTaskError):
|
||||||
|
|||||||
@@ -92,12 +92,20 @@ class RayObject {
|
|||||||
/// large to return directly as part of a gRPC response).
|
/// large to return directly as part of a gRPC response).
|
||||||
bool IsInPlasmaError() const;
|
bool IsInPlasmaError() const;
|
||||||
|
|
||||||
|
/// Mark this object as accessed before.
|
||||||
|
void SetAccessed() { accessed_ = true; };
|
||||||
|
|
||||||
|
/// Check if this object was accessed before.
|
||||||
|
bool WasAccessed() const { return accessed_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<Buffer> data_;
|
std::shared_ptr<Buffer> data_;
|
||||||
std::shared_ptr<Buffer> metadata_;
|
std::shared_ptr<Buffer> metadata_;
|
||||||
const std::vector<ObjectID> nested_ids_;
|
const std::vector<ObjectID> nested_ids_;
|
||||||
/// Whether this class holds a data copy.
|
/// Whether this class holds a data copy.
|
||||||
bool has_data_copy_;
|
bool has_data_copy_;
|
||||||
|
/// Whether this object was accessed.
|
||||||
|
bool accessed_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
},
|
},
|
||||||
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
|
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
|
||||||
options_.check_signals));
|
options_.check_signals, options_.unhandled_exception_handler));
|
||||||
|
|
||||||
auto check_node_alive_fn = [this](const NodeID &node_id) {
|
auto check_node_alive_fn = [this](const NodeID &node_id) {
|
||||||
auto node = gcs_client_->Nodes().Get(node_id);
|
auto node = gcs_client_->Nodes().Get(node_id);
|
||||||
|
|||||||
@@ -82,6 +82,7 @@ struct CoreWorkerOptions {
|
|||||||
spill_objects(nullptr),
|
spill_objects(nullptr),
|
||||||
restore_spilled_objects(nullptr),
|
restore_spilled_objects(nullptr),
|
||||||
delete_spilled_objects(nullptr),
|
delete_spilled_objects(nullptr),
|
||||||
|
unhandled_exception_handler(nullptr),
|
||||||
get_lang_stack(nullptr),
|
get_lang_stack(nullptr),
|
||||||
kill_main(nullptr),
|
kill_main(nullptr),
|
||||||
ref_counting_enabled(false),
|
ref_counting_enabled(false),
|
||||||
@@ -146,6 +147,8 @@ struct CoreWorkerOptions {
|
|||||||
/// Application-language callback to delete objects from external storage.
|
/// Application-language callback to delete objects from external storage.
|
||||||
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
|
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
|
||||||
delete_spilled_objects;
|
delete_spilled_objects;
|
||||||
|
/// Function to call on error objects never retrieved.
|
||||||
|
std::function<void(const RayObject &error)> unhandled_exception_handler;
|
||||||
/// Language worker callback to get the current call stack.
|
/// Language worker callback to get the current call stack.
|
||||||
std::function<void(std::string *)> get_lang_stack;
|
std::function<void(std::string *)> get_lang_stack;
|
||||||
// Function that tries to interrupt the currently running Python thread.
|
// Function that tries to interrupt the currently running Python thread.
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ void GetRequest::Set(const ObjectID &object_id, std::shared_ptr<RayObject> objec
|
|||||||
if (is_ready_) {
|
if (is_ready_) {
|
||||||
return; // We have already hit the number of objects to return limit.
|
return; // We have already hit the number of objects to return limit.
|
||||||
}
|
}
|
||||||
|
object->SetAccessed();
|
||||||
objects_.emplace(object_id, object);
|
objects_.emplace(object_id, object);
|
||||||
if (objects_.size() == num_objects_ ||
|
if (objects_.size() == num_objects_ ||
|
||||||
(abort_if_any_object_is_exception_ && object->IsException() &&
|
(abort_if_any_object_is_exception_ && object->IsException() &&
|
||||||
@@ -106,6 +107,7 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
|||||||
std::unique_lock<std::mutex> lock(mutex_);
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
auto iter = objects_.find(object_id);
|
auto iter = objects_.find(object_id);
|
||||||
if (iter != objects_.end()) {
|
if (iter != objects_.end()) {
|
||||||
|
iter->second->SetAccessed();
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,11 +118,13 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore(
|
|||||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
|
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
|
||||||
std::shared_ptr<ReferenceCounter> counter,
|
std::shared_ptr<ReferenceCounter> counter,
|
||||||
std::shared_ptr<raylet::RayletClient> raylet_client,
|
std::shared_ptr<raylet::RayletClient> raylet_client,
|
||||||
std::function<Status()> check_signals)
|
std::function<Status()> check_signals,
|
||||||
|
std::function<void(const RayObject &)> unhandled_exception_handler)
|
||||||
: store_in_plasma_(store_in_plasma),
|
: store_in_plasma_(store_in_plasma),
|
||||||
ref_counter_(counter),
|
ref_counter_(counter),
|
||||||
raylet_client_(raylet_client),
|
raylet_client_(raylet_client),
|
||||||
check_signals_(check_signals) {}
|
check_signals_(check_signals),
|
||||||
|
unhandled_exception_handler_(unhandled_exception_handler) {}
|
||||||
|
|
||||||
void CoreWorkerMemoryStore::GetAsync(
|
void CoreWorkerMemoryStore::GetAsync(
|
||||||
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||||
@@ -136,6 +140,7 @@ void CoreWorkerMemoryStore::GetAsync(
|
|||||||
}
|
}
|
||||||
// It's important for performance to run the callback outside the lock.
|
// It's important for performance to run the callback outside the lock.
|
||||||
if (ptr != nullptr) {
|
if (ptr != nullptr) {
|
||||||
|
ptr->SetAccessed();
|
||||||
callback(ptr);
|
callback(ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -146,6 +151,7 @@ std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
|
|||||||
auto iter = objects_.find(object_id);
|
auto iter = objects_.find(object_id);
|
||||||
if (iter != objects_.end()) {
|
if (iter != objects_.end()) {
|
||||||
auto obj = iter->second;
|
auto obj = iter->second;
|
||||||
|
obj->SetAccessed();
|
||||||
if (obj->IsInPlasmaError()) {
|
if (obj->IsInPlasmaError()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
@@ -210,6 +216,8 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
|
|||||||
if (should_add_entry) {
|
if (should_add_entry) {
|
||||||
// If there is no existing get request, then add the `RayObject` to map.
|
// If there is no existing get request, then add the `RayObject` to map.
|
||||||
objects_.emplace(object_id, object_entry);
|
objects_.emplace(object_id, object_entry);
|
||||||
|
} else {
|
||||||
|
OnErase(object_entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,6 +231,7 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
|
|||||||
|
|
||||||
// It's important for performance to run the callbacks outside the lock.
|
// It's important for performance to run the callbacks outside the lock.
|
||||||
for (const auto &cb : async_callbacks) {
|
for (const auto &cb : async_callbacks) {
|
||||||
|
object_entry->SetAccessed();
|
||||||
cb(object_entry);
|
cb(object_entry);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,6 +266,7 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
|
|||||||
const auto &object_id = object_ids[i];
|
const auto &object_id = object_ids[i];
|
||||||
auto iter = objects_.find(object_id);
|
auto iter = objects_.find(object_id);
|
||||||
if (iter != objects_.end()) {
|
if (iter != objects_.end()) {
|
||||||
|
iter->second->SetAccessed();
|
||||||
(*results)[i] = iter->second;
|
(*results)[i] = iter->second;
|
||||||
if (remove_after_get) {
|
if (remove_after_get) {
|
||||||
// Note that we cannot remove the object_id from `objects_` now,
|
// Note that we cannot remove the object_id from `objects_` now,
|
||||||
@@ -426,6 +436,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
|
|||||||
if (it->second->IsInPlasmaError()) {
|
if (it->second->IsInPlasmaError()) {
|
||||||
plasma_ids_to_delete->insert(object_id);
|
plasma_ids_to_delete->insert(object_id);
|
||||||
} else {
|
} else {
|
||||||
|
OnErase(it->second);
|
||||||
objects_.erase(it);
|
objects_.erase(it);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -435,7 +446,11 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
|
|||||||
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
|
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
for (const auto &object_id : object_ids) {
|
for (const auto &object_id : object_ids) {
|
||||||
objects_.erase(object_id);
|
auto it = objects_.find(object_id);
|
||||||
|
if (it != objects_.end()) {
|
||||||
|
OnErase(it->second);
|
||||||
|
objects_.erase(it);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -451,6 +466,14 @@ bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id, bool *in_plasma)
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void CoreWorkerMemoryStore::OnErase(std::shared_ptr<RayObject> obj) {
|
||||||
|
// TODO(ekl) note that this doesn't warn on errors that are stored in plasma.
|
||||||
|
if (obj->IsException() && !obj->IsInPlasmaError() && !obj->WasAccessed() &&
|
||||||
|
unhandled_exception_handler_ != nullptr) {
|
||||||
|
unhandled_exception_handler_(*obj);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
MemoryStoreStats CoreWorkerMemoryStore::GetMemoryStoreStatisticalData() {
|
MemoryStoreStats CoreWorkerMemoryStore::GetMemoryStoreStatisticalData() {
|
||||||
absl::MutexLock lock(&mu_);
|
absl::MutexLock lock(&mu_);
|
||||||
MemoryStoreStats item;
|
MemoryStoreStats item;
|
||||||
|
|||||||
@@ -35,7 +35,8 @@ class CoreWorkerMemoryStore {
|
|||||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
|
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
|
||||||
std::shared_ptr<ReferenceCounter> counter = nullptr,
|
std::shared_ptr<ReferenceCounter> counter = nullptr,
|
||||||
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
|
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
|
||||||
std::function<Status()> check_signals = nullptr);
|
std::function<Status()> check_signals = nullptr,
|
||||||
|
std::function<void(const RayObject &)> unhandled_exception_handler = nullptr);
|
||||||
~CoreWorkerMemoryStore(){};
|
~CoreWorkerMemoryStore(){};
|
||||||
|
|
||||||
/// Put an object with specified ID into object store.
|
/// Put an object with specified ID into object store.
|
||||||
@@ -143,6 +144,9 @@ class CoreWorkerMemoryStore {
|
|||||||
std::vector<std::shared_ptr<RayObject>> *results,
|
std::vector<std::shared_ptr<RayObject>> *results,
|
||||||
bool abort_if_any_object_is_exception);
|
bool abort_if_any_object_is_exception);
|
||||||
|
|
||||||
|
/// Called when an object is erased from the store.
|
||||||
|
void OnErase(std::shared_ptr<RayObject> obj);
|
||||||
|
|
||||||
/// Optional callback for putting objects into the plasma store.
|
/// Optional callback for putting objects into the plasma store.
|
||||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
|
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
|
||||||
|
|
||||||
@@ -173,6 +177,9 @@ class CoreWorkerMemoryStore {
|
|||||||
|
|
||||||
/// Function passed in to be called to check for signals (e.g., Ctrl-C).
|
/// Function passed in to be called to check for signals (e.g., Ctrl-C).
|
||||||
std::function<Status()> check_signals_;
|
std::function<Status()> check_signals_;
|
||||||
|
|
||||||
|
/// Function called to report unhandled exceptions.
|
||||||
|
std::function<void(const RayObject &)> unhandled_exception_handler_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ray
|
} // namespace ray
|
||||||
|
|||||||
@@ -0,0 +1,66 @@
|
|||||||
|
// Copyright 2017 The Ray Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||||
|
|
||||||
|
#include "gtest/gtest.h"
|
||||||
|
#include "ray/common/test_util.h"
|
||||||
|
|
||||||
|
namespace ray {
|
||||||
|
|
||||||
|
TEST(TestMemoryStore, TestReportUnhandledErrors) {
|
||||||
|
std::vector<std::shared_ptr<RayObject>> results;
|
||||||
|
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
|
||||||
|
int unhandled_count = 0;
|
||||||
|
|
||||||
|
std::shared_ptr<CoreWorkerMemoryStore> provider =
|
||||||
|
std::make_shared<CoreWorkerMemoryStore>(
|
||||||
|
nullptr, nullptr, nullptr, nullptr,
|
||||||
|
[&](const RayObject &obj) { unhandled_count++; });
|
||||||
|
RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
|
||||||
|
RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
|
||||||
|
auto id1 = ObjectID::FromRandom();
|
||||||
|
auto id2 = ObjectID::FromRandom();
|
||||||
|
|
||||||
|
// Check delete without get.
|
||||||
|
RAY_CHECK(provider->Put(obj1, id1));
|
||||||
|
RAY_CHECK(provider->Put(obj2, id2));
|
||||||
|
ASSERT_EQ(unhandled_count, 0);
|
||||||
|
provider->Delete({id1, id2});
|
||||||
|
ASSERT_EQ(unhandled_count, 2);
|
||||||
|
unhandled_count = 0;
|
||||||
|
|
||||||
|
// Check delete after get.
|
||||||
|
RAY_CHECK(provider->Put(obj1, id1));
|
||||||
|
RAY_CHECK(provider->Put(obj1, id2));
|
||||||
|
provider->Get({id1}, 1, 100, context, false, &results);
|
||||||
|
provider->GetOrPromoteToPlasma(id2);
|
||||||
|
provider->Delete({id1, id2});
|
||||||
|
ASSERT_EQ(unhandled_count, 0);
|
||||||
|
|
||||||
|
// Check delete after async get.
|
||||||
|
provider->GetAsync({id2}, [](std::shared_ptr<RayObject> obj) {});
|
||||||
|
RAY_CHECK(provider->Put(obj1, id1));
|
||||||
|
RAY_CHECK(provider->Put(obj2, id2));
|
||||||
|
provider->GetAsync({id1}, [](std::shared_ptr<RayObject> obj) {});
|
||||||
|
provider->Delete({id1, id2});
|
||||||
|
ASSERT_EQ(unhandled_count, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ray
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
::testing::InitGoogleTest(&argc, argv);
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user