mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:01:55 +08:00
Revert "Revert "Unhandled exception handler based on local ref counti… (#14113)
* Revert "Revert "Unhandled exception handler based on local ref counting (#14049)" (#14099)"
This reverts commit b45ae76765.
* reomve test
* fix
* fix
This commit is contained in:
+10
@@ -702,6 +702,16 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "memory_store_test",
|
||||
srcs = ["src/ray/core_worker/test/memory_store_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":core_worker_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "direct_actor_transport_test",
|
||||
srcs = ["src/ray/core_worker/test/direct_actor_transport_test.cc"],
|
||||
|
||||
+22
-3
@@ -724,6 +724,20 @@ cdef void delete_spilled_objects_handler(
|
||||
job_id=None)
|
||||
|
||||
|
||||
cdef void unhandled_exception_handler(const CRayObject& error) nogil:
|
||||
with gil:
|
||||
worker = ray.worker.global_worker
|
||||
data = None
|
||||
metadata = None
|
||||
if error.HasData():
|
||||
data = Buffer.make(error.GetData())
|
||||
if error.HasMetadata():
|
||||
metadata = Buffer.make(error.GetMetadata()).to_pybytes()
|
||||
# TODO(ekl) why does passing a ObjectRef.nil() lead to shutdown errors?
|
||||
object_ids = [None]
|
||||
worker.raise_errors([(data, metadata)], object_ids)
|
||||
|
||||
|
||||
# This function introduces ~2-7us of overhead per call (i.e., it can be called
|
||||
# up to hundreds of thousands of times per second).
|
||||
cdef void get_py_stack(c_string* stack_out) nogil:
|
||||
@@ -833,6 +847,7 @@ cdef class CoreWorker:
|
||||
options.spill_objects = spill_objects_handler
|
||||
options.restore_spilled_objects = restore_spilled_objects_handler
|
||||
options.delete_spilled_objects = delete_spilled_objects_handler
|
||||
options.unhandled_exception_handler = unhandled_exception_handler
|
||||
options.get_lang_stack = get_py_stack
|
||||
options.ref_counting_enabled = True
|
||||
options.is_local_mode = local_mode
|
||||
@@ -1443,9 +1458,13 @@ cdef class CoreWorker:
|
||||
object_ref.native())
|
||||
|
||||
def remove_object_ref_reference(self, ObjectRef object_ref):
|
||||
# Note: faster to not release GIL for short-running op.
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||
object_ref.native())
|
||||
cdef:
|
||||
CObjectID c_object_id = object_ref.native()
|
||||
# We need to release the gil since object destruction may call the
|
||||
# unhandled exception handler.
|
||||
with nogil:
|
||||
CCoreWorkerProcess.GetCoreWorker().RemoveLocalReference(
|
||||
c_object_id)
|
||||
|
||||
def serialize_and_promote_object_ref(self, ObjectRef object_ref):
|
||||
cdef:
|
||||
|
||||
@@ -250,6 +250,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
(void(
|
||||
const c_vector[c_string]&,
|
||||
CWorkerType) nogil) delete_spilled_objects
|
||||
(void(const CRayObject&) nogil) unhandled_exception_handler
|
||||
(void(c_string *stack_out) nogil) get_lang_stack
|
||||
c_bool ref_counting_enabled
|
||||
c_bool is_local_mode
|
||||
|
||||
@@ -20,6 +20,52 @@ from ray.test_utils import (wait_for_condition, SignalActor, init_error_pubsub,
|
||||
get_error_message, Semaphore)
|
||||
|
||||
|
||||
def test_unhandled_errors(ray_start_regular):
|
||||
@ray.remote
|
||||
def f():
|
||||
raise ValueError()
|
||||
|
||||
@ray.remote
|
||||
class Actor:
|
||||
def f(self):
|
||||
raise ValueError()
|
||||
|
||||
a = Actor.remote()
|
||||
num_exceptions = 0
|
||||
|
||||
def interceptor(e):
|
||||
nonlocal num_exceptions
|
||||
num_exceptions += 1
|
||||
|
||||
# Test we report unhandled exceptions.
|
||||
ray.worker._unhandled_error_handler = interceptor
|
||||
x1 = f.remote()
|
||||
x2 = a.f.remote()
|
||||
del x1
|
||||
del x2
|
||||
wait_for_condition(lambda: num_exceptions == 2)
|
||||
|
||||
# Test we don't report handled exceptions.
|
||||
x1 = f.remote()
|
||||
x2 = a.f.remote()
|
||||
with pytest.raises(ray.exceptions.RayError) as err: # noqa
|
||||
ray.get([x1, x2])
|
||||
del x1
|
||||
del x2
|
||||
time.sleep(1)
|
||||
assert num_exceptions == 2, num_exceptions
|
||||
|
||||
# Test suppression with env var works.
|
||||
try:
|
||||
os.environ["RAY_IGNORE_UNHANDLED_ERRORS"] = "1"
|
||||
x1 = f.remote()
|
||||
del x1
|
||||
time.sleep(1)
|
||||
assert num_exceptions == 2, num_exceptions
|
||||
finally:
|
||||
del os.environ["RAY_IGNORE_UNHANDLED_ERRORS"]
|
||||
|
||||
|
||||
def test_failed_task(ray_start_regular, error_pubsub):
|
||||
@ray.remote
|
||||
def throw_exception_fct1():
|
||||
|
||||
+19
-60
@@ -9,7 +9,6 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import redis
|
||||
from six.moves import queue
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -69,6 +68,12 @@ ERROR_KEY_PREFIX = b"Error:"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Visible for testing.
|
||||
def _unhandled_error_handler(e: Exception):
|
||||
logger.error("Unhandled error (suppress with "
|
||||
"RAY_IGNORE_UNHANDLED_ERRORS=1): {}".format(e))
|
||||
|
||||
|
||||
class Worker:
|
||||
"""A class used to define the control flow of a worker process.
|
||||
|
||||
@@ -277,6 +282,14 @@ class Worker:
|
||||
self.core_worker.put_serialized_object(
|
||||
serialized_value, object_ref=object_ref))
|
||||
|
||||
def raise_errors(self, data_metadata_pairs, object_refs):
|
||||
context = self.get_serialization_context()
|
||||
out = context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||
if "RAY_IGNORE_UNHANDLED_ERRORS" in os.environ:
|
||||
return
|
||||
for e in out:
|
||||
_unhandled_error_handler(e)
|
||||
|
||||
def deserialize_objects(self, data_metadata_pairs, object_refs):
|
||||
context = self.get_serialization_context()
|
||||
return context.deserialize_objects(data_metadata_pairs, object_refs)
|
||||
@@ -863,13 +876,6 @@ def custom_excepthook(type, value, tb):
|
||||
|
||||
sys.excepthook = custom_excepthook
|
||||
|
||||
# The last time we raised a TaskError in this process. We use this value to
|
||||
# suppress redundant error messages pushed from the workers.
|
||||
last_task_error_raise_time = 0
|
||||
|
||||
# The max amount of seconds to wait before printing out an uncaught error.
|
||||
UNCAUGHT_ERROR_GRACE_PERIOD = 5
|
||||
|
||||
|
||||
def print_logs(redis_client, threads_stopped, job_id):
|
||||
"""Prints log messages from workers on all of the nodes.
|
||||
@@ -1020,42 +1026,7 @@ def print_worker_logs(data: Dict[str, str], print_file: Any):
|
||||
file=print_file)
|
||||
|
||||
|
||||
def print_error_messages_raylet(task_error_queue, threads_stopped):
|
||||
"""Prints message received in the given output queue.
|
||||
|
||||
This checks periodically if any un-raised errors occurred in the
|
||||
background.
|
||||
|
||||
Args:
|
||||
task_error_queue (queue.Queue): A queue used to receive errors from the
|
||||
thread that listens to Redis.
|
||||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
|
||||
while True:
|
||||
# Exit if we received a signal that we should stop.
|
||||
if threads_stopped.is_set():
|
||||
return
|
||||
|
||||
try:
|
||||
error, t = task_error_queue.get(block=False)
|
||||
except queue.Empty:
|
||||
threads_stopped.wait(timeout=0.01)
|
||||
continue
|
||||
# Delay errors a little bit of time to attempt to suppress redundant
|
||||
# messages originating from the worker.
|
||||
while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time():
|
||||
threads_stopped.wait(timeout=1)
|
||||
if threads_stopped.is_set():
|
||||
break
|
||||
if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD:
|
||||
logger.debug(f"Suppressing error from worker: {error}")
|
||||
else:
|
||||
logger.error(f"Possible unhandled error from worker: {error}")
|
||||
|
||||
|
||||
def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
def listen_error_messages_raylet(worker, threads_stopped):
|
||||
"""Listen to error messages in the background on the driver.
|
||||
|
||||
This runs in a separate thread on the driver and pushes (error, time)
|
||||
@@ -1063,8 +1034,6 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
|
||||
Args:
|
||||
worker: The worker class that this thread belongs to.
|
||||
task_error_queue (queue.Queue): A queue used to communicate with the
|
||||
thread that prints the errors found by this thread.
|
||||
threads_stopped (threading.Event): A threading event used to signal to
|
||||
the thread that it should exit.
|
||||
"""
|
||||
@@ -1103,8 +1072,9 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
|
||||
error_message = error_data.error_message
|
||||
if (error_data.type == ray_constants.TASK_PUSH_ERROR):
|
||||
# Delay it a bit to see if we can suppress it
|
||||
task_error_queue.put((error_message, time.time()))
|
||||
# TODO(ekl) remove task push errors entirely now that we have
|
||||
# the separate unhandled exception handler.
|
||||
pass
|
||||
else:
|
||||
logger.warning(error_message)
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
@@ -1267,19 +1237,12 @@ def connect(node,
|
||||
# temporarily using this implementation which constantly queries the
|
||||
# scheduler for new error messages.
|
||||
if mode == SCRIPT_MODE:
|
||||
q = queue.Queue()
|
||||
worker.listener_thread = threading.Thread(
|
||||
target=listen_error_messages_raylet,
|
||||
name="ray_listen_error_messages",
|
||||
args=(worker, q, worker.threads_stopped))
|
||||
worker.printer_thread = threading.Thread(
|
||||
target=print_error_messages_raylet,
|
||||
name="ray_print_error_messages",
|
||||
args=(q, worker.threads_stopped))
|
||||
args=(worker, worker.threads_stopped))
|
||||
worker.listener_thread.daemon = True
|
||||
worker.listener_thread.start()
|
||||
worker.printer_thread.daemon = True
|
||||
worker.printer_thread.start()
|
||||
if log_to_driver:
|
||||
global_worker_stdstream_dispatcher.add_handler(
|
||||
"ray_print_logs", print_to_stdstream)
|
||||
@@ -1332,8 +1295,6 @@ def disconnect(exiting_interpreter=False):
|
||||
worker.import_thread.join_import_thread()
|
||||
if hasattr(worker, "listener_thread"):
|
||||
worker.listener_thread.join()
|
||||
if hasattr(worker, "printer_thread"):
|
||||
worker.printer_thread.join()
|
||||
if hasattr(worker, "logger_thread"):
|
||||
worker.logger_thread.join()
|
||||
worker.threads_stopped.clear()
|
||||
@@ -1445,13 +1406,11 @@ def get(object_refs, *, timeout=None):
|
||||
raise ValueError("'object_refs' must either be an object ref "
|
||||
"or a list of object refs.")
|
||||
|
||||
global last_task_error_raise_time
|
||||
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
||||
values, debugger_breakpoint = worker.get_objects(
|
||||
object_refs, timeout=timeout)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
last_task_error_raise_time = time.time()
|
||||
if isinstance(value, ray.exceptions.ObjectLostError):
|
||||
worker.core_worker.dump_object_store_memory_usage()
|
||||
if isinstance(value, RayTaskError):
|
||||
|
||||
@@ -92,12 +92,20 @@ class RayObject {
|
||||
/// large to return directly as part of a gRPC response).
|
||||
bool IsInPlasmaError() const;
|
||||
|
||||
/// Mark this object as accessed before.
|
||||
void SetAccessed() { accessed_ = true; };
|
||||
|
||||
/// Check if this object was accessed before.
|
||||
bool WasAccessed() const { return accessed_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Buffer> data_;
|
||||
std::shared_ptr<Buffer> metadata_;
|
||||
const std::vector<ObjectID> nested_ids_;
|
||||
/// Whether this class holds a data copy.
|
||||
bool has_data_copy_;
|
||||
/// Whether this object was accessed.
|
||||
bool accessed_ = false;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -422,7 +422,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_
|
||||
return Status::OK();
|
||||
},
|
||||
options_.ref_counting_enabled ? reference_counter_ : nullptr, local_raylet_client_,
|
||||
options_.check_signals));
|
||||
options_.check_signals, options_.unhandled_exception_handler));
|
||||
|
||||
auto check_node_alive_fn = [this](const NodeID &node_id) {
|
||||
auto node = gcs_client_->Nodes().Get(node_id);
|
||||
|
||||
@@ -82,6 +82,7 @@ struct CoreWorkerOptions {
|
||||
spill_objects(nullptr),
|
||||
restore_spilled_objects(nullptr),
|
||||
delete_spilled_objects(nullptr),
|
||||
unhandled_exception_handler(nullptr),
|
||||
get_lang_stack(nullptr),
|
||||
kill_main(nullptr),
|
||||
ref_counting_enabled(false),
|
||||
@@ -146,6 +147,8 @@ struct CoreWorkerOptions {
|
||||
/// Application-language callback to delete objects from external storage.
|
||||
std::function<void(const std::vector<std::string> &, rpc::WorkerType)>
|
||||
delete_spilled_objects;
|
||||
/// Function to call on error objects never retrieved.
|
||||
std::function<void(const RayObject &error)> unhandled_exception_handler;
|
||||
/// Language worker callback to get the current call stack.
|
||||
std::function<void(std::string *)> get_lang_stack;
|
||||
// Function that tries to interrupt the currently running Python thread.
|
||||
|
||||
@@ -93,6 +93,7 @@ void GetRequest::Set(const ObjectID &object_id, std::shared_ptr<RayObject> objec
|
||||
if (is_ready_) {
|
||||
return; // We have already hit the number of objects to return limit.
|
||||
}
|
||||
object->SetAccessed();
|
||||
objects_.emplace(object_id, object);
|
||||
if (objects_.size() == num_objects_ ||
|
||||
(abort_if_any_object_is_exception_ && object->IsException() &&
|
||||
@@ -106,6 +107,7 @@ std::shared_ptr<RayObject> GetRequest::Get(const ObjectID &object_id) const {
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
iter->second->SetAccessed();
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
@@ -116,11 +118,13 @@ CoreWorkerMemoryStore::CoreWorkerMemoryStore(
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma,
|
||||
std::shared_ptr<ReferenceCounter> counter,
|
||||
std::shared_ptr<raylet::RayletClient> raylet_client,
|
||||
std::function<Status()> check_signals)
|
||||
std::function<Status()> check_signals,
|
||||
std::function<void(const RayObject &)> unhandled_exception_handler)
|
||||
: store_in_plasma_(store_in_plasma),
|
||||
ref_counter_(counter),
|
||||
raylet_client_(raylet_client),
|
||||
check_signals_(check_signals) {}
|
||||
check_signals_(check_signals),
|
||||
unhandled_exception_handler_(unhandled_exception_handler) {}
|
||||
|
||||
void CoreWorkerMemoryStore::GetAsync(
|
||||
const ObjectID &object_id, std::function<void(std::shared_ptr<RayObject>)> callback) {
|
||||
@@ -136,6 +140,7 @@ void CoreWorkerMemoryStore::GetAsync(
|
||||
}
|
||||
// It's important for performance to run the callback outside the lock.
|
||||
if (ptr != nullptr) {
|
||||
ptr->SetAccessed();
|
||||
callback(ptr);
|
||||
}
|
||||
}
|
||||
@@ -146,6 +151,7 @@ std::shared_ptr<RayObject> CoreWorkerMemoryStore::GetOrPromoteToPlasma(
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
auto obj = iter->second;
|
||||
obj->SetAccessed();
|
||||
if (obj->IsInPlasmaError()) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -210,6 +216,8 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
|
||||
if (should_add_entry) {
|
||||
// If there is no existing get request, then add the `RayObject` to map.
|
||||
objects_.emplace(object_id, object_entry);
|
||||
} else {
|
||||
OnErase(object_entry);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,6 +231,7 @@ bool CoreWorkerMemoryStore::Put(const RayObject &object, const ObjectID &object_
|
||||
|
||||
// It's important for performance to run the callbacks outside the lock.
|
||||
for (const auto &cb : async_callbacks) {
|
||||
object_entry->SetAccessed();
|
||||
cb(object_entry);
|
||||
}
|
||||
|
||||
@@ -257,6 +266,7 @@ Status CoreWorkerMemoryStore::GetImpl(const std::vector<ObjectID> &object_ids,
|
||||
const auto &object_id = object_ids[i];
|
||||
auto iter = objects_.find(object_id);
|
||||
if (iter != objects_.end()) {
|
||||
iter->second->SetAccessed();
|
||||
(*results)[i] = iter->second;
|
||||
if (remove_after_get) {
|
||||
// Note that we cannot remove the object_id from `objects_` now,
|
||||
@@ -426,6 +436,7 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
|
||||
if (it->second->IsInPlasmaError()) {
|
||||
plasma_ids_to_delete->insert(object_id);
|
||||
} else {
|
||||
OnErase(it->second);
|
||||
objects_.erase(it);
|
||||
}
|
||||
}
|
||||
@@ -435,7 +446,11 @@ void CoreWorkerMemoryStore::Delete(const absl::flat_hash_set<ObjectID> &object_i
|
||||
void CoreWorkerMemoryStore::Delete(const std::vector<ObjectID> &object_ids) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
for (const auto &object_id : object_ids) {
|
||||
objects_.erase(object_id);
|
||||
auto it = objects_.find(object_id);
|
||||
if (it != objects_.end()) {
|
||||
OnErase(it->second);
|
||||
objects_.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -451,6 +466,14 @@ bool CoreWorkerMemoryStore::Contains(const ObjectID &object_id, bool *in_plasma)
|
||||
return false;
|
||||
}
|
||||
|
||||
void CoreWorkerMemoryStore::OnErase(std::shared_ptr<RayObject> obj) {
|
||||
// TODO(ekl) note that this doesn't warn on errors that are stored in plasma.
|
||||
if (obj->IsException() && !obj->IsInPlasmaError() && !obj->WasAccessed() &&
|
||||
unhandled_exception_handler_ != nullptr) {
|
||||
unhandled_exception_handler_(*obj);
|
||||
}
|
||||
}
|
||||
|
||||
MemoryStoreStats CoreWorkerMemoryStore::GetMemoryStoreStatisticalData() {
|
||||
absl::MutexLock lock(&mu_);
|
||||
MemoryStoreStats item;
|
||||
|
||||
@@ -35,7 +35,8 @@ class CoreWorkerMemoryStore {
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma = nullptr,
|
||||
std::shared_ptr<ReferenceCounter> counter = nullptr,
|
||||
std::shared_ptr<raylet::RayletClient> raylet_client = nullptr,
|
||||
std::function<Status()> check_signals = nullptr);
|
||||
std::function<Status()> check_signals = nullptr,
|
||||
std::function<void(const RayObject &)> unhandled_exception_handler = nullptr);
|
||||
~CoreWorkerMemoryStore(){};
|
||||
|
||||
/// Put an object with specified ID into object store.
|
||||
@@ -143,6 +144,9 @@ class CoreWorkerMemoryStore {
|
||||
std::vector<std::shared_ptr<RayObject>> *results,
|
||||
bool abort_if_any_object_is_exception);
|
||||
|
||||
/// Called when an object is erased from the store.
|
||||
void OnErase(std::shared_ptr<RayObject> obj);
|
||||
|
||||
/// Optional callback for putting objects into the plasma store.
|
||||
std::function<void(const RayObject &, const ObjectID &)> store_in_plasma_;
|
||||
|
||||
@@ -173,6 +177,9 @@ class CoreWorkerMemoryStore {
|
||||
|
||||
/// Function passed in to be called to check for signals (e.g., Ctrl-C).
|
||||
std::function<Status()> check_signals_;
|
||||
|
||||
/// Function called to report unhandled exceptions.
|
||||
std::function<void(const RayObject &)> unhandled_exception_handler_;
|
||||
};
|
||||
|
||||
} // namespace ray
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
// Copyright 2017 The Ray Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ray/common/test_util.h"
|
||||
|
||||
namespace ray {
|
||||
|
||||
TEST(TestMemoryStore, TestReportUnhandledErrors) {
|
||||
std::vector<std::shared_ptr<RayObject>> results;
|
||||
WorkerContext context(WorkerType::WORKER, WorkerID::FromRandom(), JobID::FromInt(0));
|
||||
int unhandled_count = 0;
|
||||
|
||||
std::shared_ptr<CoreWorkerMemoryStore> provider =
|
||||
std::make_shared<CoreWorkerMemoryStore>(
|
||||
nullptr, nullptr, nullptr, nullptr,
|
||||
[&](const RayObject &obj) { unhandled_count++; });
|
||||
RayObject obj1(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
|
||||
RayObject obj2(rpc::ErrorType::TASK_EXECUTION_EXCEPTION);
|
||||
auto id1 = ObjectID::FromRandom();
|
||||
auto id2 = ObjectID::FromRandom();
|
||||
|
||||
// Check delete without get.
|
||||
RAY_CHECK(provider->Put(obj1, id1));
|
||||
RAY_CHECK(provider->Put(obj2, id2));
|
||||
ASSERT_EQ(unhandled_count, 0);
|
||||
provider->Delete({id1, id2});
|
||||
ASSERT_EQ(unhandled_count, 2);
|
||||
unhandled_count = 0;
|
||||
|
||||
// Check delete after get.
|
||||
RAY_CHECK(provider->Put(obj1, id1));
|
||||
RAY_CHECK(provider->Put(obj1, id2));
|
||||
provider->Get({id1}, 1, 100, context, false, &results);
|
||||
provider->GetOrPromoteToPlasma(id2);
|
||||
provider->Delete({id1, id2});
|
||||
ASSERT_EQ(unhandled_count, 0);
|
||||
|
||||
// Check delete after async get.
|
||||
provider->GetAsync({id2}, [](std::shared_ptr<RayObject> obj) {});
|
||||
RAY_CHECK(provider->Put(obj1, id1));
|
||||
RAY_CHECK(provider->Put(obj2, id2));
|
||||
provider->GetAsync({id1}, [](std::shared_ptr<RayObject> obj) {});
|
||||
provider->Delete({id1, id2});
|
||||
ASSERT_EQ(unhandled_count, 0);
|
||||
}
|
||||
|
||||
} // namespace ray
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
::testing::InitGoogleTest(&argc, argv);
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user