mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 04:07:01 +08:00
Warn the user when a nondeterministic task is detected. (#339)
* WARN instead of FATAL for object hash mismatches, push error to driver * Document the callback signature for object_table_add/remove * Error table * Wait for all errors in python test * Fix doc * Fix state test
This commit is contained in:
committed by
Robert Nishihara
parent
0b8d279ef2
commit
da06b4db82
@@ -65,14 +65,22 @@ class TestGlobalStateStore(unittest.TestCase):
|
||||
# Check that Redis returns an error when RAY.OBJECT_TABLE_ADD adds an object
|
||||
# ID that is already present with a different hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1"})
|
||||
with self.assertRaises(redis.ResponseError):
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
|
||||
# Check that the second manager was added, even though the hash was
|
||||
# mismatched.
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
# Check that it is fine if we add the same object ID multiple times with the
|
||||
# same hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash1", "manager_id2")
|
||||
# most recent hash.
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id1")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash2", "manager_id2")
|
||||
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 2, "hash2", "manager_id2")
|
||||
response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1")
|
||||
self.assertEqual(set(response), {b"manager_id1", b"manager_id2"})
|
||||
|
||||
def testObjectTableAddAndLookup(self):
|
||||
# Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been
|
||||
|
||||
@@ -721,50 +721,6 @@ class TestPlasmaManager(unittest.TestCase):
|
||||
assert_get_object_equal(self, self.client1, self.client2, object_id2,
|
||||
memory_buffer=memory_buffer2, metadata=metadata2)
|
||||
|
||||
def test_illegal_put(self):
|
||||
"""
|
||||
Test doing a put at the same object ID, but with different object data. The
|
||||
first put should succeed. The second put should cause the plasma manager to
|
||||
exit with a fatal error.
|
||||
"""
|
||||
if USE_VALGRIND:
|
||||
# Don't run this test when we are using valgrind because when processes
|
||||
# die without freeing up their state, valgrind complains.
|
||||
return
|
||||
# Create and seal the first object.
|
||||
length = 1000
|
||||
object_id = random_object_id()
|
||||
memory_buffer1 = self.client1.create(object_id, length)
|
||||
for i in range(length):
|
||||
memory_buffer1[i] = chr(i % 256)
|
||||
self.client1.seal(object_id)
|
||||
# Create and seal the second object. It has all the same data as the first
|
||||
# object, with one bit flipped.
|
||||
memory_buffer2 = self.client2.create(object_id, length)
|
||||
for i in range(length):
|
||||
j = i
|
||||
if j == 0:
|
||||
j += 1
|
||||
memory_buffer2[i] = chr(j % 256)
|
||||
self.client2.seal(object_id)
|
||||
# Make sure that one of the plasma managers exited (the second one to call
|
||||
# RAY.OBJECT_TABLE_ADD should have exited). In the vast majority of cases,
|
||||
# this should be p5. However, on Travis, it is frequently p4.
|
||||
time_left = 100
|
||||
while time_left > 0:
|
||||
if self.p5.poll() != None:
|
||||
self.processes_to_kill.remove(self.p5)
|
||||
break
|
||||
if self.p4.poll() != None:
|
||||
self.processes_to_kill.remove(self.p4)
|
||||
break
|
||||
time_left -= 0.1
|
||||
time.sleep(0.1)
|
||||
|
||||
print("Time waiting for plasma manager to fail = {:.2}".format(100 - time_left))
|
||||
# Check that exactly one of the plasma managers has died.
|
||||
self.assertEqual([self.p5.poll(), self.p4.poll()].count(None), 1)
|
||||
|
||||
def test_illegal_functionality(self):
|
||||
# Create an object id string.
|
||||
object_id = random_object_id()
|
||||
|
||||
@@ -49,6 +49,10 @@ NIL_ACTOR_ID = 20 * b"\xff"
|
||||
# fetch the object again.
|
||||
GET_TIMEOUT_MILLISECONDS = 1000
|
||||
|
||||
# This must be kept in sync with the `error_types` array in
|
||||
# common/state/error_table.h.
|
||||
OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch"
|
||||
|
||||
def random_string():
|
||||
return np.random.bytes(20)
|
||||
|
||||
@@ -677,6 +681,12 @@ def error_info(worker=global_worker):
|
||||
for error_key in error_keys:
|
||||
if error_applies_to_driver(error_key, worker=worker):
|
||||
error_contents = worker.redis_client.hgetall(error_key)
|
||||
# If the error is an object hash mismatch, look up the function name for
|
||||
# the nondeterministic task.
|
||||
if error_contents[b"type"] == OBJECT_HASH_MISMATCH_ERROR_TYPE:
|
||||
function_id = error_contents[b"data"]
|
||||
function_name = worker.redis_client.hget("RemoteFunction:{}".format(function_id), "name")
|
||||
error_contents[b"data"] = function_name
|
||||
errors.append(error_contents)
|
||||
|
||||
return errors
|
||||
|
||||
@@ -46,6 +46,7 @@ add_library(common STATIC
|
||||
state/db_client_table.cc
|
||||
state/actor_notification_table.cc
|
||||
state/local_scheduler_table.cc
|
||||
state/error_table.cc
|
||||
thirdparty/ae/ae.c
|
||||
thirdparty/sha256.c)
|
||||
|
||||
|
||||
@@ -388,9 +388,9 @@ bool PublishObjectNotification(RedisModuleCtx *ctx,
|
||||
* @param hash_string A string which is a hash of the object.
|
||||
* @param manager A string which represents the manager ID of the plasma manager
|
||||
* that has the object.
|
||||
* @return OK if the operation was successful and an error with string
|
||||
* "hash mismatch" if the same object_id is already present with a
|
||||
* different hash value.
|
||||
* @return OK if the operation was successful. If the same object_id is already
|
||||
* present with a different hash value, the entry is still added, but
|
||||
* an error with string "hash mismatch" is returned.
|
||||
*/
|
||||
int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx,
|
||||
RedisModuleString **argv,
|
||||
@@ -416,6 +416,7 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx,
|
||||
REDISMODULE_READ | REDISMODULE_WRITE);
|
||||
|
||||
/* Check if this object was already registered and if the hashes agree. */
|
||||
bool hash_mismatch = false;
|
||||
if (RedisModule_KeyType(key) != REDISMODULE_KEYTYPE_EMPTY) {
|
||||
RedisModuleString *existing_hash;
|
||||
RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "hash", &existing_hash,
|
||||
@@ -423,11 +424,9 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx,
|
||||
/* The existing hash may be NULL even if the key is present because a call
|
||||
* to RAY.RESULT_TABLE_ADD may have already created the key. */
|
||||
if (existing_hash != NULL) {
|
||||
if (RedisModule_StringCompare(existing_hash, new_hash) != 0) {
|
||||
RedisModule_CloseKey(key);
|
||||
RedisModule_FreeString(ctx, existing_hash);
|
||||
return RedisModule_ReplyWithError(ctx, "hash mismatch");
|
||||
}
|
||||
/* Check whether the new hash value matches the old one. If not, we will
|
||||
* later return the "hash mismatch" error. */
|
||||
hash_mismatch = (RedisModule_StringCompare(existing_hash, new_hash) != 0);
|
||||
RedisModule_FreeString(ctx, existing_hash);
|
||||
}
|
||||
}
|
||||
@@ -493,8 +492,12 @@ int ObjectTableAdd_RedisCommand(RedisModuleCtx *ctx,
|
||||
}
|
||||
|
||||
RedisModule_CloseKey(table_key);
|
||||
RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
return REDISMODULE_OK;
|
||||
if (hash_mismatch) {
|
||||
return RedisModule_ReplyWithError(ctx, "hash mismatch");
|
||||
} else {
|
||||
RedisModule_ReplyWithSimpleString(ctx, "OK");
|
||||
return REDISMODULE_OK;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
#include "error_table.h"
|
||||
#include "redis.h"
|
||||
|
||||
void push_error(DBHandle *db_handle,
|
||||
DBClientID driver_id,
|
||||
int error_index,
|
||||
size_t data_length,
|
||||
unsigned char *data) {
|
||||
CHECK(error_index >= 0 && error_index < MAX_ERROR_INDEX);
|
||||
/* Allocate a struct to hold the error information. */
|
||||
ErrorInfo *info = (ErrorInfo *) malloc(sizeof(ErrorInfo) + data_length);
|
||||
info->driver_id = driver_id;
|
||||
info->error_index = error_index;
|
||||
info->data_length = data_length;
|
||||
memcpy(info->data, data, data_length);
|
||||
/* Generate a random key to identify this error message. */
|
||||
CHECK(sizeof(info->error_key) >= UNIQUE_ID_SIZE);
|
||||
UniqueID error_key = globally_unique_id();
|
||||
memcpy(info->error_key, error_key.id, sizeof(info->error_key));
|
||||
|
||||
init_table_callback(db_handle, NIL_ID, __func__, info, NULL, NULL,
|
||||
redis_push_error, NULL);
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef ERROR_TABLE_H
|
||||
#define ERROR_TABLE_H
|
||||
|
||||
#include "db.h"
|
||||
#include "table.h"
|
||||
|
||||
typedef struct {
|
||||
DBClientID driver_id;
|
||||
unsigned char error_key[20];
|
||||
int error_index;
|
||||
size_t data_length;
|
||||
unsigned char data[0];
|
||||
} ErrorInfo;
|
||||
|
||||
/** An error_index may be used as an index into error_types and
|
||||
* error_messages. */
|
||||
typedef enum {
|
||||
/** An object was added with a different hash from the existing
|
||||
* one. */
|
||||
OBJECT_HASH_MISMATCH_ERROR_INDEX = 0,
|
||||
/** The total number of error types. */
|
||||
MAX_ERROR_INDEX
|
||||
} error_index;
|
||||
|
||||
/** Information about the error to be displayed to the user. */
|
||||
static const char *error_types[] = {"object_hash_mismatch"};
|
||||
static const char *error_messages[] = {
|
||||
"A nondeterministic task was reexecuted."};
|
||||
|
||||
/**
|
||||
* Push an error to the given Python driver.
|
||||
*
|
||||
* @param db_handle Database handle.
|
||||
* @param driver_id The ID of the Python driver to push the error
|
||||
* to.
|
||||
* @param error_index The error information at this index in
|
||||
* error_types and error_messages will be included in the
|
||||
* error pushed to the driver.
|
||||
* @param data_length The length of the custom data to be included
|
||||
* in the error.
|
||||
* @param data The custom data to be included in the error.
|
||||
* @return Void.
|
||||
*/
|
||||
void push_error(DBHandle *db_handle,
|
||||
DBClientID driver_id,
|
||||
int error_index,
|
||||
size_t data_length,
|
||||
unsigned char *data);
|
||||
|
||||
#endif
|
||||
@@ -49,8 +49,18 @@ void object_table_lookup(DBHandle *db_handle,
|
||||
* ==== Add object call and callback ====
|
||||
*/
|
||||
|
||||
/* Callback called when the object add/remove operation completes. */
|
||||
/**
|
||||
* Callback called when the object add/remove operation completes.
|
||||
*
|
||||
* @param object_id The ID of the object that was added or removed.
|
||||
* @param success Whether the operation was successful or not. If this is false
|
||||
* and the operation was an addition, the object was added, but there
|
||||
* was a hash mismatch.
|
||||
* @param user_context The user context that was passed into the add/remove
|
||||
* call.
|
||||
*/
|
||||
typedef void (*object_table_done_callback)(ObjectID object_id,
|
||||
bool success,
|
||||
void *user_context);
|
||||
|
||||
/**
|
||||
|
||||
@@ -22,6 +22,7 @@ extern "C" {
|
||||
#include "object_info.h"
|
||||
#include "task.h"
|
||||
#include "task_table.h"
|
||||
#include "error_table.h"
|
||||
#include "event_loop.h"
|
||||
#include "redis.h"
|
||||
#include "io.h"
|
||||
@@ -217,21 +218,23 @@ void redis_object_table_add_callback(redisAsyncContext *c,
|
||||
|
||||
/* Do some minimal checking. */
|
||||
redisReply *reply = (redisReply *) r;
|
||||
if (strcmp(reply->str, "hash mismatch") == 0) {
|
||||
bool success = (strcmp(reply->str, "hash mismatch") != 0);
|
||||
if (!success) {
|
||||
/* If our object hash doesn't match the one recorded in the table, report
|
||||
* the error back to the user and exit immediately. */
|
||||
LOG_FATAL(
|
||||
LOG_WARN(
|
||||
"Found objects with different value but same object ID, most likely "
|
||||
"because a nondeterministic task was executed twice, either for "
|
||||
"reconstruction or for speculation.");
|
||||
} else {
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECK(strcmp(reply->str, "OK") == 0);
|
||||
}
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECK(strcmp(reply->str, "OK") == 0);
|
||||
/* Call the done callback if there is one. */
|
||||
if (callback_data->done_callback != NULL) {
|
||||
object_table_done_callback done_callback =
|
||||
(object_table_done_callback) callback_data->done_callback;
|
||||
done_callback(callback_data->id, callback_data->user_context);
|
||||
done_callback(callback_data->id, success, callback_data->user_context);
|
||||
}
|
||||
/* Clean up the timer and callback. */
|
||||
destroy_timer_callback(db->loop, callback_data);
|
||||
@@ -274,7 +277,7 @@ void redis_object_table_remove_callback(redisAsyncContext *c,
|
||||
if (callback_data->done_callback != NULL) {
|
||||
object_table_done_callback done_callback =
|
||||
(object_table_done_callback) callback_data->done_callback;
|
||||
done_callback(callback_data->id, callback_data->user_context);
|
||||
done_callback(callback_data->id, true, callback_data->user_context);
|
||||
}
|
||||
/* Clean up the timer and callback. */
|
||||
destroy_timer_callback(db->loop, callback_data);
|
||||
@@ -1275,6 +1278,58 @@ void redis_object_info_subscribe(TableCallbackData *callback_data) {
|
||||
}
|
||||
}
|
||||
|
||||
void redis_push_error_rpush_callback(redisAsyncContext *c,
|
||||
void *r,
|
||||
void *privdata) {
|
||||
REDIS_CALLBACK_HEADER(db, callback_data, r);
|
||||
redisReply *reply = (redisReply *) r;
|
||||
/* The reply should be the length of the errors list after our RPUSH. */
|
||||
CHECK(reply->type == REDIS_REPLY_INTEGER);
|
||||
destroy_timer_callback(db->loop, callback_data);
|
||||
}
|
||||
|
||||
void redis_push_error_hmset_callback(redisAsyncContext *c,
|
||||
void *r,
|
||||
void *privdata) {
|
||||
REDIS_CALLBACK_HEADER(db, callback_data, r);
|
||||
redisReply *reply = (redisReply *) r;
|
||||
|
||||
/* Make sure we were able to add the error information. */
|
||||
CHECK(reply->type != REDIS_REPLY_ERROR);
|
||||
CHECK(strcmp(reply->str, "OK") == 0);
|
||||
|
||||
/* Add the error to this driver's list of errors. */
|
||||
ErrorInfo *info = (ErrorInfo *) callback_data->data;
|
||||
int status = redisAsyncCommand(db->context, redis_push_error_rpush_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"RPUSH ErrorKeys Error:%b:%b",
|
||||
info->driver_id.id, sizeof(info->driver_id.id),
|
||||
info->error_key, sizeof(info->error_key));
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error rpush");
|
||||
}
|
||||
}
|
||||
|
||||
void redis_push_error(TableCallbackData *callback_data) {
|
||||
DBHandle *db = callback_data->db_handle;
|
||||
ErrorInfo *info = (ErrorInfo *) callback_data->data;
|
||||
CHECK(info->error_index < MAX_ERROR_INDEX && info->error_index >= 0);
|
||||
/* Look up the error type. */
|
||||
const char *error_type = error_types[info->error_index];
|
||||
const char *error_message = error_messages[info->error_index];
|
||||
|
||||
/* Set the error information. */
|
||||
int status = redisAsyncCommand(
|
||||
db->context, redis_push_error_hmset_callback,
|
||||
(void *) callback_data->timer_id,
|
||||
"HMSET Error:%b:%b type %s message %s data %b", info->driver_id.id,
|
||||
sizeof(info->driver_id.id), info->error_key, sizeof(info->error_key),
|
||||
error_type, error_message, info->data, info->data_length);
|
||||
if ((status == REDIS_ERR) || db->sub_context->err) {
|
||||
LOG_REDIS_DEBUG(db->sub_context, "error in redis_push_error hmset");
|
||||
}
|
||||
}
|
||||
|
||||
DBClientID get_db_client_id(DBHandle *db) {
|
||||
CHECK(db != NULL);
|
||||
return db->client;
|
||||
|
||||
@@ -264,4 +264,6 @@ void redis_actor_notification_table_subscribe(TableCallbackData *callback_data);
|
||||
|
||||
void redis_object_info_subscribe(TableCallbackData *callback_data);
|
||||
|
||||
void redis_push_error(TableCallbackData *callback_data);
|
||||
|
||||
#endif /* REDIS_H */
|
||||
|
||||
@@ -55,7 +55,7 @@ void lookup_done_callback(ObjectID object_id,
|
||||
}
|
||||
|
||||
/* Entry added to database successfully. */
|
||||
void add_done_callback(ObjectID object_id, void *user_context) {}
|
||||
void add_done_callback(ObjectID object_id, bool success, void *user_context) {}
|
||||
|
||||
/* Test if we got a timeout callback if we couldn't connect database. */
|
||||
void timeout_callback(ObjectID object_id, void *context, void *user_data) {
|
||||
|
||||
@@ -172,7 +172,7 @@ TEST lookup_timeout_test(void) {
|
||||
const char *add_timeout_context = "add_timeout";
|
||||
int add_failed = 0;
|
||||
|
||||
void add_done_callback(ObjectID object_id, void *user_context) {
|
||||
void add_done_callback(ObjectID object_id, bool success, void *user_context) {
|
||||
/* The done callback should not be called. */
|
||||
CHECK(0);
|
||||
}
|
||||
@@ -305,7 +305,8 @@ void add_lookup_done_callback(ObjectID object_id,
|
||||
lookup_retry_succeeded = 1;
|
||||
}
|
||||
|
||||
void add_lookup_callback(ObjectID object_id, void *user_context) {
|
||||
void add_lookup_callback(ObjectID object_id, bool success, void *user_context) {
|
||||
CHECK(success);
|
||||
DBHandle *db = (DBHandle *) user_context;
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -353,7 +354,10 @@ void add_remove_lookup_done_callback(ObjectID object_id,
|
||||
lookup_retry_succeeded = 1;
|
||||
}
|
||||
|
||||
void add_remove_lookup_callback(ObjectID object_id, void *user_context) {
|
||||
void add_remove_lookup_callback(ObjectID object_id,
|
||||
bool success,
|
||||
void *user_context) {
|
||||
CHECK(success);
|
||||
DBHandle *db = (DBHandle *) user_context;
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -364,7 +368,8 @@ void add_remove_lookup_callback(ObjectID object_id, void *user_context) {
|
||||
(void *) lookup_retry_context);
|
||||
}
|
||||
|
||||
void add_remove_callback(ObjectID object_id, void *user_context) {
|
||||
void add_remove_callback(ObjectID object_id, bool success, void *user_context) {
|
||||
CHECK(success);
|
||||
DBHandle *db = (DBHandle *) user_context;
|
||||
RetryInfo retry = {
|
||||
.num_retries = 5,
|
||||
@@ -482,7 +487,9 @@ void add_late_fail_callback(UniqueID id, void *user_context, void *user_data) {
|
||||
add_late_failed = 1;
|
||||
}
|
||||
|
||||
void add_late_done_callback(ObjectID object_id, void *user_context) {
|
||||
void add_late_done_callback(ObjectID object_id,
|
||||
bool success,
|
||||
void *user_context) {
|
||||
/* This function should never be called. */
|
||||
CHECK(0);
|
||||
}
|
||||
|
||||
@@ -351,7 +351,9 @@ TaskSpec *object_reconstruction_suppression_spec;
|
||||
int64_t object_reconstruction_suppression_size;
|
||||
|
||||
void object_reconstruction_suppression_callback(ObjectID object_id,
|
||||
bool success,
|
||||
void *user_context) {
|
||||
CHECK(success);
|
||||
/* Submit the task after adding the object to the object table. */
|
||||
LocalSchedulerConnection *worker = (LocalSchedulerConnection *) user_context;
|
||||
local_scheduler_submit(worker, object_reconstruction_suppression_spec,
|
||||
|
||||
@@ -35,6 +35,8 @@
|
||||
#include "plasma_manager.h"
|
||||
#include "state/db.h"
|
||||
#include "state/object_table.h"
|
||||
#include "state/error_table.h"
|
||||
#include "state/task_table.h"
|
||||
|
||||
/**
|
||||
* Process either the fetch or the status request.
|
||||
@@ -1270,6 +1272,44 @@ void process_delete_object_notification(PlasmaManagerState *state,
|
||||
* up-to-date. */
|
||||
}
|
||||
|
||||
void log_object_hash_mismatch_error_task_callback(Task *task,
|
||||
void *user_context) {
|
||||
CHECK(task != NULL);
|
||||
PlasmaManagerState *state = (PlasmaManagerState *) user_context;
|
||||
TaskSpec *spec = Task_task_spec(task);
|
||||
FunctionID function = TaskSpec_function(spec);
|
||||
/* Push the error to the Python driver that caused the nondeterministic task
|
||||
* to be submitted. */
|
||||
push_error(state->db, TaskSpec_driver_id(spec),
|
||||
OBJECT_HASH_MISMATCH_ERROR_INDEX, sizeof(function), function.id);
|
||||
}
|
||||
|
||||
void log_object_hash_mismatch_error_result_callback(ObjectID object_id,
|
||||
TaskID task_id,
|
||||
void *user_context) {
|
||||
CHECK(!IS_NIL_ID(task_id));
|
||||
PlasmaManagerState *state = (PlasmaManagerState *) user_context;
|
||||
/* Get the specification for the nondeterministic task. */
|
||||
task_table_get_task(state->db, task_id, NULL,
|
||||
log_object_hash_mismatch_error_task_callback, state);
|
||||
}
|
||||
|
||||
void log_object_hash_mismatch_error_object_callback(ObjectID object_id,
|
||||
bool success,
|
||||
void *user_context) {
|
||||
if (success) {
|
||||
/* The object was added successfully. */
|
||||
return;
|
||||
}
|
||||
|
||||
/* The object was added, but there was an object hash mismatch. In this case,
|
||||
* look up the task that created the object so we can notify the Python
|
||||
* driver that the task is nondeterministic. */
|
||||
PlasmaManagerState *state = (PlasmaManagerState *) user_context;
|
||||
result_table_lookup(state->db, object_id, NULL,
|
||||
log_object_hash_mismatch_error_result_callback, state);
|
||||
}
|
||||
|
||||
void process_add_object_notification(PlasmaManagerState *state,
|
||||
ObjectInfo object_info) {
|
||||
ObjectID obj_id = object_info.obj_id;
|
||||
@@ -1280,11 +1320,10 @@ void process_add_object_notification(PlasmaManagerState *state,
|
||||
|
||||
/* Add this object to the (redis) object table. */
|
||||
if (state->db) {
|
||||
/* TODO(swang): Log the error if we fail to add the object, and possibly
|
||||
* retry later? */
|
||||
object_table_add(state->db, obj_id,
|
||||
object_info.data_size + object_info.metadata_size,
|
||||
object_info.digest, NULL, NULL, NULL);
|
||||
object_table_add(
|
||||
state->db, obj_id, object_info.data_size + object_info.metadata_size,
|
||||
object_info.digest, NULL,
|
||||
log_object_hash_mismatch_error_object_callback, (void *) state);
|
||||
}
|
||||
|
||||
/* If we were trying to fetch this object, finish up the fetch request. */
|
||||
|
||||
@@ -294,6 +294,65 @@ class ReconstructionTests(unittest.TestCase):
|
||||
value = ray.get(args[i])
|
||||
self.assertEqual(value[0], i)
|
||||
|
||||
def testNondeterministicTask(self):
|
||||
# Define the size of one task's return argument so that the combined sum of
|
||||
# all objects' sizes is at least twice the plasma stores' combined allotted
|
||||
# memory.
|
||||
num_objects = 1000
|
||||
size = self.plasma_store_memory * 2 // (num_objects * 8)
|
||||
|
||||
# Define a nondeterministic remote task with no dependencies, which returns
|
||||
# a random numpy array of the given size. This task should produce an error
|
||||
# on the driver if it is ever reexecuted.
|
||||
@ray.remote
|
||||
def foo(i, size):
|
||||
array = np.random.rand(size)
|
||||
array[0] = i
|
||||
return array
|
||||
|
||||
# Define a deterministic remote task with no dependencies, which returns a
|
||||
# numpy array of zeros of the given size.
|
||||
@ray.remote
|
||||
def bar(i, size):
|
||||
array = np.zeros(size)
|
||||
array[0] = i
|
||||
return array
|
||||
|
||||
# Launch num_objects instances, half deterministic and half
|
||||
# nondeterministic.
|
||||
args = []
|
||||
for i in range(num_objects):
|
||||
if i % 2 == 0:
|
||||
args.append(foo.remote(i, size))
|
||||
else:
|
||||
args.append(bar.remote(i, size))
|
||||
|
||||
# Get each value to force each task to finish. After some number of gets,
|
||||
# old values should be evicted.
|
||||
for i in range(num_objects):
|
||||
value = ray.get(args[i])
|
||||
self.assertEqual(value[0], i)
|
||||
# Get each value again to force reconstruction.
|
||||
for i in range(num_objects):
|
||||
value = ray.get(args[i])
|
||||
self.assertEqual(value[0], i)
|
||||
|
||||
# Wait for errors from all the nondeterministic tasks.
|
||||
time_left = 100
|
||||
while time_left > 0:
|
||||
errors = ray.error_info()
|
||||
if len(errors) >= num_objects / 2:
|
||||
break
|
||||
time_left -= 0.1
|
||||
time.sleep(0.1)
|
||||
|
||||
# Make sure that enough errors came through.
|
||||
self.assertTrue(len(errors) >= num_objects / 2)
|
||||
# Make sure all the errors have the correct type.
|
||||
self.assertTrue(all(error[b"type"] == b"object_hash_mismatch" for error in errors))
|
||||
# Make sure all the errors have the correct function name.
|
||||
self.assertTrue(all(error[b"data"] == b"__main__.foo" for error in errors))
|
||||
|
||||
class ReconstructionTestsMultinode(ReconstructionTests):
|
||||
|
||||
# Run the same tests as the single-node suite, but with 4 local schedulers,
|
||||
|
||||
Reference in New Issue
Block a user