diff --git a/python/common/redis_module/runtest.py b/python/common/redis_module/runtest.py index 45ed3c4c3..0cfc02169 100644 --- a/python/common/redis_module/runtest.py +++ b/python/common/redis_module/runtest.py @@ -9,6 +9,7 @@ import sys import time import unittest import redis +import ray.services # Check if the redis-server binary is present. redis_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "../../core/src/common/thirdparty/redis/src/redis-server") @@ -42,17 +43,11 @@ def integerToAsciiHex(num, numbytes): class TestGlobalStateStore(unittest.TestCase): def setUp(self): - redis_port = random.randint(2000, 50000) - self.redis_process = subprocess.Popen([redis_path, - "--port", str(redis_port), - "--loglevel", "warning", - "--loadmodule", module_path]) - time.sleep(1.5) + redis_port = ray.services.start_redis() self.redis = redis.StrictRedis(host="localhost", port=redis_port, db=0) def tearDown(self): - self.redis_process.kill() - + ray.services.cleanup() def testInvalidObjectTableAdd(self): # Check that Redis returns an error when RAY.OBJECT_TABLE_ADD is called with @@ -81,7 +76,7 @@ class TestGlobalStateStore(unittest.TestCase): # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been # added yet. response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") - self.assertEqual(set(response), set([])) + self.assertEqual(response, None) # Add some managers and try again. self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2") @@ -109,7 +104,7 @@ class TestGlobalStateStore(unittest.TestCase): # Try calling RAY.OBJECT_TABLE_LOOKUP with an object ID that has not been # added yet. response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") - self.assertEqual(set(response), set([])) + self.assertEqual(response, None) # Add some managers and try again. 
self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id2") @@ -131,7 +126,7 @@ class TestGlobalStateStore(unittest.TestCase): self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id2") response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") self.assertEqual(set(response), set()) - # Remove a manager from an empty set, and make sure we still have an empty set. + # Remove a manager from an empty set, and make sure we now have an empty set. self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", "object_id1", "manager_id3") response = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", "object_id1") self.assertEqual(set(response), set()) @@ -173,24 +168,19 @@ class TestGlobalStateStore(unittest.TestCase): self.redis.execute_command("RAY.OBJECT_TABLE_ADD", "object_id1", 1, "hash1", "manager_id1") response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") self.assertIsNone(response) - # Add the result to the result table. This is necessary, but not sufficient - # because the task is still not in the task table. - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", "task_id1") + # Add the result to the result table. The lookup now returns the task ID. + task_id = b"task_id1" + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id1", task_id) response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") - self.assertIsNone(response) - # Add the task to the task table so that the result table lookup can - # succeed. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id1", 1, "local_scheduler_id1", "task_spec1") - response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") - self.assertEqual(response, [1, b"local_scheduler_id1", b"task_spec1"]) + self.assertEqual(response, task_id) # Doing it again should still work. 
response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id1") - self.assertEqual(response, [1, b"local_scheduler_id1", b"task_spec1"]) + self.assertEqual(response, task_id) # Try another result table lookup. This should succeed. - self.redis.execute_command("RAY.TASK_TABLE_ADD", "task_id2", 2, "local_scheduler_id2", "task_spec2") - self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", "task_id2") + task_id = b"task_id2" + self.redis.execute_command("RAY.RESULT_TABLE_ADD", "object_id2", task_id) response = self.redis.execute_command("RAY.RESULT_TABLE_LOOKUP", "object_id2") - self.assertEqual(response, [2, b"local_scheduler_id2", b"task_spec2"]) + self.assertEqual(response, task_id) def testInvalidTaskTableAdd(self): # Check that Redis returns an error when RAY.TASK_TABLE_ADD is called with @@ -227,6 +217,40 @@ class TestGlobalStateStore(unittest.TestCase): response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") self.assertEqual(response, task_args) + # If the current value, test value, and set value are all the same, the + # update happens, and the response is still the same task. + task_args = [task_args[0]] + task_args + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + self.assertEqual(response, task_args[1:]) + # Check that the task entry is still the same. + get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + self.assertEqual(get_response, task_args[1:]) + + # If the current value is the same as the test value, and the set value is + # different, the update happens, and the response is the entire task. + task_args[1] += 1 + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + self.assertEqual(response, task_args[1:]) + # Check that the update happened. 
+ get_response = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + self.assertEqual(get_response, task_args[1:]) + + # If the current value is no longer the same as the test value, the + # response is nil. + task_args[1] += 1 + response = self.redis.execute_command("RAY.TASK_TABLE_TEST_AND_UPDATE", + "task_id", + *task_args[:3]) + self.assertEqual(response, None) + # Check that the update did not happen. + get_response2 = self.redis.execute_command("RAY.TASK_TABLE_GET", "task_id") + self.assertEqual(get_response2, get_response) + self.assertNotEqual(get_response2, task_args[1:]) + def testTaskTableSubscribe(self): scheduling_state = 1 node_id = "node_id" diff --git a/python/plasma/plasma.py b/python/plasma/plasma.py index aea9b993d..94cf92d45 100644 --- a/python/plasma/plasma.py +++ b/python/plasma/plasma.py @@ -151,7 +151,8 @@ class PlasmaClient(object): Args: object_ids (List[str]): A list of strings used to identify some objects. timeout_ms (int): The number of milliseconds that the get call should - block before timing out and returning. + block before timing out and returning. Pass -1 if the call should block + and 0 if the call should return immediately. """ results = libplasma.get(self.conn, object_ids, timeout_ms) assert len(object_ids) == len(results) @@ -172,7 +173,8 @@ class PlasmaClient(object): Args: object_ids (List[str]): A list of strings used to identify some objects. timeout_ms (int): The number of milliseconds that the get call should - block before timing out and returning. + block before timing out and returning. Pass -1 if the call should block + and 0 if the call should return immediately. 
""" results = libplasma.get(self.conn, object_ids, timeout_ms) assert len(object_ids) == len(results) diff --git a/python/ray/services.py b/python/ray/services.py index 4fc524f1e..71aa106fe 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -55,6 +55,13 @@ ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name", def address(ip_address, port): return ip_address + ":" + str(port) +def get_ip_address(address): + try: + ip_address = address.split(":")[0] + except: + raise Exception("Unable to parse IP address from address {}".format(address)) + return ip_address + def get_port(address): try: port = int(address.split(":")[1]) @@ -430,7 +437,8 @@ def start_ray_processes(address_info=None, # A Redis address was provided, so start a Redis server with the given # port. TODO(rkn): We should check that the IP address corresponds to the # machine that this method is running on. - redis_ip_address, redis_port = redis_address.split(":") + redis_ip_address = get_ip_address(redis_address) + redis_port = get_port(redis_address) new_redis_port = start_redis(port=int(redis_port), num_retries=1, cleanup=cleanup, diff --git a/python/ray/worker.py b/python/ray/worker.py index 3d7ee5b03..aa19bd9a2 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -39,6 +39,10 @@ ERROR_KEY_PREFIX = b"Error:" DRIVER_ID_LENGTH = 20 ERROR_ID_LENGTH = 20 +# When performing ray.get, wait 1 second before attemping to reconstruct and +# fetch the object again. +GET_TIMEOUT_MILLISECONDS = 1000 + def random_string(): return np.random.bytes(20) @@ -421,13 +425,13 @@ class Worker(object): # Serialize and put the object in the object store. try: numbuf.store_list(objectid.id(), self.plasma_client.conn, [value]) - except plasma.plasma_object_exists_error as e: + except numbuf.numbuf_plasma_object_exists_error as e: # The object already exists in the object store, so there is no need to # add it again. 
TODO(rkn): We need to compare the hashes and make sure # that the objects are in fact the same. We also should return an error # code to the caller instead of printing a message. print("This object already exists in the object store.") - return + global contained_objectids # Optionally do something with the contained_objectids here. contained_objectids = [] @@ -443,18 +447,37 @@ class Worker(object): values should be retrieved. """ self.plasma_client.fetch([object_id.id() for object_id in object_ids]) - # We currently pass in a timeout of one second. - unready_ids = object_ids + + # Get the objects. We initially try to get the objects immediately. + final_results = numbuf.retrieve_list( + [object_id.id() for object_id in object_ids], + self.plasma_client.conn, + 0) + # Construct a dictionary mapping object IDs that we haven't gotten yet to + # their original index in the object_ids argument. + unready_ids = dict((object_id, i) for (i, (object_id, val)) in + enumerate(final_results) if val is None) + # Try reconstructing any objects we haven't gotten yet. Try to get them + # until GET_TIMEOUT_MILLISECONDS milliseconds passes, then repeat. while len(unready_ids) > 0: - results = numbuf.retrieve_list([object_id.id() for object_id in object_ids], self.plasma_client.conn, 1000) - unready_ids = [object_id for (object_id, val) in results if val is None] - # This would be a natural place to issue a command to reconstruct some of - # the objects. + for unready_id in unready_ids: + self.photon_client.reconstruct_object(unready_id) + results = numbuf.retrieve_list(list(unready_ids.keys()), + self.plasma_client.conn, + GET_TIMEOUT_MILLISECONDS) + # Remove any entries for objects we received during this iteration so we + # don't retrieve the same object twice. 
+ for object_id, val in results: + if val is not None: + index = unready_ids[object_id] + final_results[index] = (object_id, val) + unready_ids.pop(object_id) + # Unwrap the object from the list (it was wrapped put_object). - assert len(results) == len(object_ids) - for i in range(len(results)): - assert results[i][0] == object_ids[i].id() - return [result[1][0] for result in results] + assert len(final_results) == len(object_ids) + for i in range(len(final_results)): + assert final_results[i][0] == object_ids[i].id() + return [result[1][0] for result in final_results] def submit_task(self, function_id, func_name, args): """Submit a remote task to the scheduler. diff --git a/src/common/lib/python/common_extension.c b/src/common/lib/python/common_extension.c index 4d49f91f2..4d43de959 100644 --- a/src/common/lib/python/common_extension.c +++ b/src/common/lib/python/common_extension.c @@ -35,6 +35,16 @@ void init_pickle_module(void) { /* Define the PyObjectID class. */ +int PyStringToUniqueID(PyObject *object, object_id *object_id) { + if (PyBytes_Check(object)) { + memcpy(&object_id->id[0], PyBytes_AsString(object), UNIQUE_ID_SIZE); + return 1; + } else { + PyErr_SetString(PyExc_TypeError, "must be a 20 character string"); + return 0; + } +} + int PyObjectToUniqueID(PyObject *object, object_id *objectid) { if (PyObject_IsInstance(object, (PyObject *) &PyObjectIDType)) { *objectid = ((PyObjectID *) object)->object_id; diff --git a/src/common/lib/python/common_extension.h b/src/common/lib/python/common_extension.h index 9a5cfe851..9637697f3 100644 --- a/src/common/lib/python/common_extension.h +++ b/src/common/lib/python/common_extension.h @@ -33,6 +33,8 @@ extern PyObject *pickle_loads; void init_pickle_module(void); +int PyStringToUniqueID(PyObject *object, object_id *object_id); + int PyObjectToUniqueID(PyObject *object, object_id *objectid); PyObject *PyObjectID_make(object_id object_id); diff --git a/src/common/redis_module/ray_redis_module.c 
b/src/common/redis_module/ray_redis_module.c index 831848282..94fc779c1 100644 --- a/src/common/redis_module/ray_redis_module.c +++ b/src/common/redis_module/ray_redis_module.c @@ -188,8 +188,9 @@ int GetClientAddress_RedisCommand(RedisModuleCtx *ctx, * RAY.OBJECT_TABLE_LOOKUP * * @param object_id A string representing the object ID. - * @return A list of plasma manager IDs that are listed in the object table as - * having the object. + * @return A list, possibly empty, of plasma manager IDs that are listed in the + * object table as having the object. If there was no entry found in + * the object table, returns nil. */ int ObjectTableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, @@ -201,8 +202,12 @@ int ObjectTableLookup_RedisCommand(RedisModuleCtx *ctx, RedisModuleKey *key = OpenPrefixedKey(ctx, OBJECT_LOCATION_PREFIX, argv[1], REDISMODULE_READ); - if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY || - RedisModule_ValueLength(key) == 0) { + if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { + /* Return nil if no entry was found. */ + return RedisModule_ReplyWithNull(ctx); + } + if (RedisModule_ValueLength(key) == 0) { + /* Return empty list if there are no managers. */ return RedisModule_ReplyWithArray(ctx, 0); } @@ -581,6 +586,35 @@ int ResultTableAdd_RedisCommand(RedisModuleCtx *ctx, return REDISMODULE_OK; } +int ParseTaskState(RedisModuleString *state) { + size_t state_length; + const char *state_string = RedisModule_StringPtrLen(state, &state_length); + int state_integer; + int scanned = sscanf(state_string, "%2d", &state_integer); + if (scanned != 1 || state_length != 2) { + return -1; + } + return state_integer; +} + +RedisModuleString *NormalizeTaskState(RedisModuleCtx *ctx, + RedisModuleString *state) { + /* Pad the state integer to a fixed-width integer, and make sure it has width + * less than or equal to 2. 
*/ + long long state_integer; + int status = RedisModule_StringToLongLong(state, &state_integer); + if (status != REDISMODULE_OK) { + return NULL; + } + state = RedisModule_CreateStringPrintf(ctx, "%2d", state_integer); + size_t length; + RedisModule_StringPtrLen(state, &length); + if (length != 2) { + return NULL; + } + return state; +} + /** * Reply with information about a task ID. This is used by * RAY.RESULT_TABLE_LOOKUP and RAY.TASK_TABLE_GET. @@ -609,11 +643,8 @@ int ReplyWithTask(RedisModuleCtx *ctx, RedisModuleString *task_id) { ctx, "Missing fields in the task table entry"); } - size_t state_length; - const char *state_string = RedisModule_StringPtrLen(state, &state_length); - int state_integer; - int scanned = sscanf(state_string, "%2d", &state_integer); - if (scanned != 1 || state_length != 2) { + int state_integer = ParseTaskState(state); + if (state_integer < 0) { RedisModule_CloseKey(key); RedisModule_FreeString(ctx, state); RedisModule_FreeString(ctx, local_scheduler_id); @@ -668,23 +699,21 @@ int ResultTableLookup_RedisCommand(RedisModuleCtx *ctx, key = OpenPrefixedKey(ctx, OBJECT_INFO_PREFIX, object_id, REDISMODULE_READ); if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { + RedisModule_CloseKey(key); return RedisModule_ReplyWithNull(ctx); } RedisModuleString *task_id; RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "task", &task_id, NULL); + RedisModule_CloseKey(key); if (task_id == NULL) { return RedisModule_ReplyWithNull(ctx); } - /* Construct a reply by getting the task from the task ID. */ - int status = ReplyWithTask(ctx, task_id); - - /* Clean up. 
*/ + RedisModule_ReplyWithString(ctx, task_id); RedisModule_FreeString(ctx, task_id); - RedisModule_CloseKey(key); - return status; + return REDISMODULE_OK; } int TaskTableWrite(RedisModuleCtx *ctx, @@ -694,18 +723,11 @@ int TaskTableWrite(RedisModuleCtx *ctx, RedisModuleString *task_spec) { /* Pad the state integer to a fixed-width integer, and make sure it has width * less than or equal to 2. */ - long long state_integer; - int status = RedisModule_StringToLongLong(state, &state_integer); - if (status != REDISMODULE_OK) { + state = NormalizeTaskState(ctx, state); + if (state == NULL) { return RedisModule_ReplyWithError( - ctx, "Invalid scheduling state (must be an integer)"); - } - state = RedisModule_CreateStringPrintf(ctx, "%2d", state_integer); - size_t length; - RedisModule_StringPtrLen(state, &length); - if (length != 2) { - return RedisModule_ReplyWithError( - ctx, "Invalid scheduling state width (must have width 2)"); + ctx, + "Invalid scheduling state (must be an integer of width at most 2)"); } /* Add the task to the task table. If no spec was provided, get the existing @@ -720,6 +742,7 @@ int TaskTableWrite(RedisModuleCtx *ctx, &existing_task_spec, NULL); if (existing_task_spec == NULL) { RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); return RedisModule_ReplyWithError( ctx, "Cannot update a task that doesn't exist yet"); } @@ -743,6 +766,7 @@ int TaskTableWrite(RedisModuleCtx *ctx, publish_message = RedisString_Format(ctx, "%S %S %S %S", task_id, state, node_id, existing_task_spec); } + RedisModule_FreeString(ctx, state); RedisModuleCallReply *reply = RedisModule_Call(ctx, "PUBLISH", "ss", publish_topic, publish_message); @@ -820,6 +844,93 @@ int TaskTableUpdate_RedisCommand(RedisModuleCtx *ctx, return TaskTableWrite(ctx, argv[1], argv[2], argv[3], NULL); } +/** + * Test and update an entry in the task table if the current value matches the + * test value. This does not update the task specification in the table. 
+ * + * This is called from a client with the command: + * + * RAY.TASK_TABLE_TEST_AND_UPDATE + * + * + * @param task_id A string that is the ID of the task. + * @param test_state A string that is the test value for the scheduling state. + * The update happens if and only if the current scheduling state + * matches this value. + * @param state A string that is the scheduling state (a scheduling_state enum + * instance) to update the task entry with. The string's value must be a + * nonnegative integer less than 100, so that it has width at most 2. If + * less than 2, the value will be left-padded with spaces to a width of + * 2. + * @param ray_client_id A string that is the ray client ID of the associated + * local scheduler, if any, to update the task entry with. + * @return If the current scheduling state does not match the test value, + * returns nil. Else, returns the same as RAY.TASK_TABLE_GET: an array + * of strings representing the updated task fields in the following + * order: 1) (integer) scheduling state 2) (string) associated node ID, + * if any 3) (string) the task specification, which can be casted to a + * task_spec. + */ +int TaskTableTestAndUpdate_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + if (argc != 5) { + return RedisModule_WrongArity(ctx); + } + + RedisModuleString *state = NormalizeTaskState(ctx, argv[3]); + if (state == NULL) { + return RedisModule_ReplyWithError( + ctx, + "Invalid scheduling state (must be an integer of width at most 2)"); + } + + RedisModuleKey *key = OpenPrefixedKey(ctx, TASK_PREFIX, argv[1], + REDISMODULE_READ | REDISMODULE_WRITE); + if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) { + RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); + return RedisModule_ReplyWithNull(ctx); + } + + /* If the key exists, look up the fields and return them in an array. 
*/ + RedisModuleString *current_state = NULL; + RedisModule_HashGet(key, REDISMODULE_HASH_CFIELDS, "state", &current_state, + NULL); + int current_state_integer = ParseTaskState(current_state); + if (current_state_integer < 0) { + RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); + return RedisModule_ReplyWithError(ctx, + "Found invalid scheduling state (must " + "be an integer of width 2"); + } + long long test_state_integer; + int status = RedisModule_StringToLongLong(argv[2], &test_state_integer); + if (status != REDISMODULE_OK) { + RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); + return RedisModule_ReplyWithError( + ctx, "Invalid test value for scheduling state"); + } + if (current_state_integer != test_state_integer) { + /* The current value does not match the test value, so do not perform the + * update. */ + RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); + return RedisModule_ReplyWithNull(ctx); + } + + /* The test passed, so perform the update. */ + RedisModule_HashSet(key, REDISMODULE_HASH_CFIELDS, "state", state, "node", + argv[4], NULL); + /* Clean up. */ + RedisModule_CloseKey(key); + RedisModule_FreeString(ctx, state); + /* Construct a reply by getting the task from the task ID. */ + return ReplyWithTask(ctx, argv[1]); +} + /** * Get an entry from the task table. 
* @@ -922,6 +1033,12 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx, return REDISMODULE_ERR; } + if (RedisModule_CreateCommand(ctx, "ray.task_table_test_and_update", + TaskTableTestAndUpdate_RedisCommand, + "write pubsub", 0, 0, 0) == REDISMODULE_ERR) { + return REDISMODULE_ERR; + } + if (RedisModule_CreateCommand(ctx, "ray.task_table_get", TaskTableGet_RedisCommand, "readonly", 0, 0, 0) == REDISMODULE_ERR) { diff --git a/src/common/state/object_table.h b/src/common/state/object_table.h index 6ed364d35..80cca0bcb 100644 --- a/src/common/state/object_table.h +++ b/src/common/state/object_table.h @@ -11,7 +11,9 @@ */ /* Callback called when the lookup completes. The callback should free - * the manager_vector array, but NOT the strings they are pointing to. + * the manager_vector array, but NOT the strings they are pointing to. If there + * was no entry at all for the object (the object had never been created + * before), then manager_count will be -1. */ typedef void (*object_table_lookup_done_callback)( object_id object_id, @@ -234,11 +236,12 @@ void result_table_add(db_handle *db_handle, /** Callback called when the result table lookup completes. */ typedef void (*result_table_lookup_callback)(object_id object_id, - task *task, + task_id task_id, void *user_context); /** - * Lookup the task that created an object in the result table. + * Lookup the task that created an object in the result table. The return value + * is the task ID. * * @param db_handle Handle to object_table database. * @param object_id ID of the object to lookup. 
diff --git a/src/common/state/redis.c b/src/common/state/redis.c index 07def309e..52c6a6b4c 100644 --- a/src/common/state/redis.c +++ b/src/common/state/redis.c @@ -401,16 +401,20 @@ void redis_result_table_lookup_callback(redisAsyncContext *c, void *privdata) { REDIS_CALLBACK_HEADER(db, callback_data, r); redisReply *reply = r; + CHECKM(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_STRING, + "Unexpected reply type %d in redis_result_table_lookup_callback", + reply->type); /* Parse the task from the reply. */ - task *task = parse_and_construct_task_from_redis_reply(reply); + task_id result_id = NIL_TASK_ID; + if (reply->type == REDIS_REPLY_STRING) { + CHECK(reply->len == sizeof(result_id)); + memcpy(&result_id, reply->str, reply->len); + } + /* Call the done callback if there is one. */ result_table_lookup_callback done_callback = callback_data->done_callback; if (done_callback != NULL) { - done_callback(callback_data->id, task, callback_data->user_context); - } - /* Free the task if it is not NULL. */ - if (task != NULL) { - free_task(task); + done_callback(callback_data->id, result_id, callback_data->user_context); } /* Clean up timer and callback. 
*/ destroy_timer_callback(db->loop, callback_data); @@ -465,24 +469,33 @@ void redis_object_table_lookup_callback(redisAsyncContext *c, void *privdata) { REDIS_CALLBACK_HEADER(db, callback_data, r); redisReply *reply = r; + LOG_DEBUG("Object table lookup callback"); + CHECK(reply->type == REDIS_REPLY_NIL || reply->type == REDIS_REPLY_ARRAY); object_id obj_id = callback_data->id; - - LOG_DEBUG("Object table lookup callback"); - CHECK(reply->type == REDIS_REPLY_ARRAY); - - int64_t manager_count = reply->elements; + int64_t manager_count = 0; db_client_id *managers = NULL; const char **manager_vector = NULL; - if (manager_count > 0) { - managers = malloc(reply->elements * sizeof(db_client_id)); - manager_vector = malloc(manager_count * sizeof(char *)); - } - for (int j = 0; j < reply->elements; ++j) { - CHECK(reply->element[j]->type == REDIS_REPLY_STRING); - memcpy(managers[j].id, reply->element[j]->str, sizeof(managers[j].id)); - redis_get_cached_db_client(db, managers[j], manager_vector + j); + + /* Parse the Redis reply. */ + if (reply->type == REDIS_REPLY_NIL) { + /* The object entry did not exist. 
*/ + manager_count = -1; + } else if (reply->type == REDIS_REPLY_ARRAY) { + manager_count = reply->elements; + if (manager_count > 0) { + managers = malloc(reply->elements * sizeof(db_client_id)); + manager_vector = malloc(manager_count * sizeof(char *)); + } + for (int j = 0; j < reply->elements; ++j) { + CHECK(reply->element[j]->type == REDIS_REPLY_STRING); + memcpy(managers[j].id, reply->element[j]->str, sizeof(managers[j].id)); + redis_get_cached_db_client(db, managers[j], manager_vector + j); + } + } else { + LOG_FATAL("Unexpected reply type from object table lookup."); } + object_table_lookup_done_callback done_callback = callback_data->done_callback; if (done_callback) { @@ -821,6 +834,43 @@ void redis_task_table_update(table_callback_data *callback_data) { } } +void redis_task_table_test_and_update_callback(redisAsyncContext *c, + void *r, + void *privdata) { + REDIS_CALLBACK_HEADER(db, callback_data, r); + redisReply *reply = r; + /* Parse the task from the reply. */ + task *task = parse_and_construct_task_from_redis_reply(reply); + /* Call the done callback if there is one. */ + task_table_get_callback done_callback = callback_data->done_callback; + if (done_callback != NULL) { + done_callback(task, callback_data->user_context); + } + /* Free the task if it is not NULL. */ + if (task != NULL) { + free_task(task); + } + /* Clean up timer and callback. 
*/ + destroy_timer_callback(db->loop, callback_data); +} + +void redis_task_table_test_and_update(table_callback_data *callback_data) { + db_handle *db = callback_data->db_handle; + task_id task_id = callback_data->id; + task_table_test_and_update_data *update_data = callback_data->data; + + int status = redisAsyncCommand( + db->context, redis_task_table_test_and_update_callback, + (void *) callback_data->timer_id, + "RAY.TASK_TABLE_TEST_AND_UPDATE %b %d %d %b", task_id.id, + sizeof(task_id.id), update_data->test_state, update_data->update_state, + update_data->local_scheduler_id.id, + sizeof(update_data->local_scheduler_id.id)); + if ((status == REDIS_ERR) || db->context->err) { + LOG_REDIS_DEBUG(db->context, "error in redis_task_table_test_and_update"); + } +} + /* The format of the payload is described in ray_redis_module.c and is * " ". TODO(rkn): * Make this code nicer. */ diff --git a/src/common/state/redis.h b/src/common/state/redis.h index 73061dcf6..1d4e6b780 100644 --- a/src/common/state/redis.h +++ b/src/common/state/redis.h @@ -121,8 +121,7 @@ void redis_object_table_request_notifications( void redis_result_table_add(table_callback_data *callback_data); /** - * Lookup the object in the object table in redis. The entry in - * the object table contains metadata about the object. + * Lookup the task that created the object in redis. The result is the task ID. * * @param callback_data Data structure containing redis connection and timeout * information. @@ -176,6 +175,16 @@ void redis_task_table_add_task(table_callback_data *callback_data); */ void redis_task_table_update(table_callback_data *callback_data); +/** + * Update a task table entry with the task's scheduling information, if the + * task's current scheduling information matches the test value. + * + * @param callback_data Data structure containing redis connection and timeout + * information. + * @return Void. 
+ */ +void redis_task_table_test_and_update(table_callback_data *callback_data); + /** * Callback invoked when the reply from the task push command is received. * diff --git a/src/common/state/task_table.c b/src/common/state/task_table.c index 4e67e4bf6..6a95ef1d7 100644 --- a/src/common/state/task_table.c +++ b/src/common/state/task_table.c @@ -30,6 +30,24 @@ void task_table_update(db_handle *db_handle, done_callback, redis_task_table_update, user_context); } +void task_table_test_and_update(db_handle *db_handle, + task_id task_id, + scheduling_state test_state, + scheduling_state update_state, + retry_info *retry, + task_table_get_callback done_callback, + void *user_context) { + task_table_test_and_update_data *update_data = + malloc(sizeof(task_table_test_and_update_data)); + update_data->test_state = test_state; + update_data->update_state = update_state; + /* Update the task entry's local scheduler with this client's ID. */ + update_data->local_scheduler_id = db_handle->client; + init_table_callback(db_handle, task_id, __func__, update_data, retry, + done_callback, redis_task_table_test_and_update, + user_context); +} + /* TODO(swang): A corresponding task_table_unsubscribe. */ void task_table_subscribe(db_handle *db_handle, db_client_id local_scheduler_id, diff --git a/src/common/state/task_table.h b/src/common/state/task_table.h index 2cf8b4eef..5068f237d 100644 --- a/src/common/state/task_table.h +++ b/src/common/state/task_table.h @@ -87,6 +87,40 @@ void task_table_update(db_handle *db_handle, task_table_done_callback done_callback, void *user_context); +/** + * Update a task's scheduling information in the task table, if the current + * value matches the given test value. If the update succeeds, it also updates + * the task entry's local scheduler ID with the ID of the client who called + * this function. This assumes that the task spec already exists in the task + * table entry. + * + * @param db_handle Database handle. 
+ * @param task_id The task ID of the task entry to update. + * @param test_state The value to test the current task entry's scheduling + * state against. + * @param update_state The value to update the task entry's scheduling state + * with, if the current state matches test_state. + * @param retry Information about retrying the request to the database. + * @param done_callback Function to be called when database returns result. + * @param user_context Data that will be passed to done_callback and + * fail_callback. + * @return Void. + */ +void task_table_test_and_update(db_handle *db_handle, + task_id task_id, + scheduling_state test_state, + scheduling_state update_state, + retry_info *retry, + task_table_get_callback done_callback, + void *user_context); + +/* Data that is needed to test and set the task's scheduling state. */ +typedef struct { + scheduling_state test_state; + scheduling_state update_state; + db_client_id local_scheduler_id; +} task_table_test_and_update_data; + /* * ==== Subscribing to the task table ==== */ diff --git a/src/common/task.c b/src/common/task.c index c18061ae3..1c9fa0e24 100644 --- a/src/common/task.c +++ b/src/common/task.c @@ -173,14 +173,6 @@ void finish_construct_task_spec(task_spec *spec) { } } -task_spec *alloc_nil_task_spec(task_id task_id) { - task_spec *spec = - start_construct_task_spec(NIL_ID, NIL_ID, 0, NIL_FUNCTION_ID, 0, 0, 0); - finish_construct_task_spec(spec); - spec->task_id = task_id; - return spec; -} - int64_t task_spec_size(task_spec *spec) { return TASK_SPEC_SIZE(spec->num_args, spec->num_returns, spec->args_value_size); @@ -332,13 +324,6 @@ task *copy_task(task *other) { return copy; } -task *alloc_nil_task(task_id task_id) { - task_spec *nil_spec = alloc_nil_task_spec(task_id); - task *nil_task = alloc_task(nil_spec, 0, NIL_ID); - free_task_spec(nil_spec); - return nil_task; -} - int64_t task_size(task *task_arg) { return sizeof(task) - sizeof(task_spec) + task_spec_size(&task_arg->spec); } diff --git 
a/src/common/task.h b/src/common/task.h index e60065589..45094852a 100644 --- a/src/common/task.h +++ b/src/common/task.h @@ -273,7 +273,9 @@ typedef enum { /** The task is running on a worker. */ TASK_STATUS_RUNNING = 8, /** The task is done executing. */ - TASK_STATUS_DONE = 16 + TASK_STATUS_DONE = 16, + /** The task will be submitted for reexecution. */ + TASK_STATUS_RECONSTRUCTING = 32 } scheduling_state; /** A task is an execution of a task specification. It has a state of execution @@ -325,14 +327,4 @@ task_id task_task_id(task *task); /** Free this task datastructure. */ void free_task(task *task); -/** - * ==== Task update ==== - * Contains the information necessary to update a task in the task log. - */ - -typedef struct { - scheduling_state state; - db_client_id local_scheduler_id; -} task_update; - #endif diff --git a/src/common/test/object_table_tests.c b/src/common/test/object_table_tests.c index 39c337ff2..5aa80fcf8 100644 --- a/src/common/test/object_table_tests.c +++ b/src/common/test/object_table_tests.c @@ -31,12 +31,11 @@ void new_object_fail_callback(unique_id id, /* === Test adding an object with an associated task === */ void new_object_done_callback(object_id object_id, - task *task, + task_id task_id, void *user_context) { new_object_succeeded = 1; CHECK(object_ids_equal(object_id, new_object_id)); - CHECK(task); - CHECK(memcmp(task, new_object_task, task_size(task)) == 0); + CHECK(task_ids_equal(task_id, new_object_task_id)); event_loop_stop(g_loop); } @@ -92,26 +91,14 @@ TEST new_object_test(void) { /* === Test adding an object without an associated task === */ -void new_object_no_task_lookup_callback(object_id object_id, - task *task, - void *user_context) { +void new_object_no_task_callback(object_id object_id, + task_id task_id, + void *user_context) { new_object_succeeded = 1; - CHECK(task == NULL); + CHECK(IS_NIL_ID(task_id)); event_loop_stop(g_loop); } -void new_object_no_task_callback(object_id object_id, void *user_context) { - 
CHECK(object_ids_equal(object_id, new_object_id)); - retry_info retry = { - .num_retries = 5, - .timeout = 100, - .fail_callback = new_object_fail_callback, - }; - db_handle *db = user_context; - result_table_lookup(db, object_id, &retry, new_object_no_task_lookup_callback, - NULL); -} - TEST new_object_no_task_test(void) { new_object_failed = 0; new_object_succeeded = 0; @@ -126,8 +113,8 @@ TEST new_object_no_task_test(void) { .timeout = 100, .fail_callback = new_object_fail_callback, }; - result_table_add(db, new_object_id, new_object_task_id, &retry, - new_object_no_task_callback, db); + result_table_lookup(db, new_object_id, &retry, new_object_no_task_callback, + NULL); event_loop_run(g_loop); db_disconnect(db); destroy_outstanding_callbacks(g_loop); diff --git a/src/numbuf/CMakeLists.txt b/src/numbuf/CMakeLists.txt index a7e6bf5fb..43cbc23ab 100644 --- a/src/numbuf/CMakeLists.txt +++ b/src/numbuf/CMakeLists.txt @@ -48,6 +48,7 @@ if(HAS_PLASMA) include_directories("${CMAKE_CURRENT_LIST_DIR}/../common") include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/thirdparty") include_directories("${CMAKE_CURRENT_LIST_DIR}/../common/build/flatcc-prefix/src/flatcc/include") + set(COMMON_EXTENSION ../common/lib/python/common_extension.c) endif() add_definitions(-fPIC) @@ -58,7 +59,8 @@ add_library(numbuf SHARED cpp/src/numbuf/sequence.cc python/src/pynumbuf/numbuf.cc python/src/pynumbuf/adapters/numpy.cc - python/src/pynumbuf/adapters/python.cc) + python/src/pynumbuf/adapters/python.cc + ${COMMON_EXTENSION}) get_filename_component(PYTHON_SHARED_LIBRARY ${PYTHON_LIBRARIES} NAME) if(APPLE) diff --git a/src/numbuf/python/src/pynumbuf/numbuf.cc b/src/numbuf/python/src/pynumbuf/numbuf.cc index d54ec6091..ba0246a89 100644 --- a/src/numbuf/python/src/pynumbuf/numbuf.cc +++ b/src/numbuf/python/src/pynumbuf/numbuf.cc @@ -193,6 +193,7 @@ static PyObject* register_callbacks(PyObject* self, PyObject* args) { #ifdef HAS_PLASMA +#include "common_extension.h" #include 
"plasma_extension.h" /** @@ -297,9 +298,12 @@ static PyObject* store_list(PyObject* self, PyObject* args) { * Python objects from the plasma data according to the schema and * returns the object. * - * @param args The first argument is a list of object IDs of the lists to be - * retrieved and the second argument is the connection to the plasma - * store. + * @param args The arguments are, in order: + * 1) A list of object IDs of the lists to be retrieved. + * 2) The connection to the plasma store. + * 3) A timeout in milliseconds that the call should return by. This is + * -1 if the call should block forever, or 0 if the call should + * return immediately. * @return A list of tuples, where the first element in the tuple is the object * ID (appearing in the same order as in the argument to the method), * and the second element in the tuple is the retrieved list (or None) @@ -436,7 +440,7 @@ MOD_INIT(libnumbuf) { PyErr_NewException(numbuf_plasma_object_exists_error, NULL, NULL); Py_INCREF(NumbufPlasmaObjectExistsError); PyModule_AddObject( - m, "pnumbuf_lasma_object_exists_error", NumbufPlasmaObjectExistsError); + m, "numbuf_plasma_object_exists_error", NumbufPlasmaObjectExistsError); /* Create a custom exception for when the plasma store is out of memory. 
*/ char numbuf_plasma_out_of_memory_error[] = "numbuf_plasma_out_of_memory.error"; NumbufPlasmaOutOfMemoryError = diff --git a/src/photon/photon_algorithm.c b/src/photon/photon_algorithm.c index 73d9ba42f..bc156096f 100644 --- a/src/photon/photon_algorithm.c +++ b/src/photon/photon_algorithm.c @@ -233,6 +233,9 @@ int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context) { ++i; } plasma_fetch(state->plasma_conn, num_object_ids, object_ids); + for (int i = 0; i < num_object_ids; ++i) { + reconstruct_object(state, object_ids[i]); + } free(object_ids); return LOCAL_SCHEDULER_FETCH_TIMEOUT_MILLISECONDS; } diff --git a/src/photon/photon_extension.c b/src/photon/photon_extension.c index 7b9f752d1..428a8cf04 100644 --- a/src/photon/photon_extension.c +++ b/src/photon/photon_extension.c @@ -54,7 +54,7 @@ static PyObject *PyPhotonClient_get_task(PyObject *self) { static PyObject *PyPhotonClient_reconstruct_object(PyObject *self, PyObject *args) { object_id object_id; - if (!PyArg_ParseTuple(args, "O&", &PyObjectToUniqueID, &object_id)) { + if (!PyArg_ParseTuple(args, "O&", PyStringToUniqueID, &object_id)) { return NULL; } photon_reconstruct_object(((PyPhotonClient *) self)->photon_connection, diff --git a/src/photon/photon_scheduler.c b/src/photon/photon_scheduler.c index 89fc4dd0c..3ef2be28a 100644 --- a/src/photon/photon_scheduler.c +++ b/src/photon/photon_scheduler.c @@ -177,45 +177,61 @@ void process_plasma_notification(event_loop *loop, } } -void reconstruct_object_task_lookup_callback(object_id reconstruct_object_id, - task *task, - void *user_context) { - /* Recursively resubmit the task and its task lineage to the scheduler. */ - CHECKM(task != NULL, - "No task information found for object during reconstruction"); - local_scheduler_state *state = user_context; - /* If the task's scheduling state is pending completion, assume that - * reconstruction is already being taken care of and cancel this - * reconstruction operation. 
NOTE: This codepath is not responsible for - * detecting failure of the other reconstruction, or updating the - * scheduling_state accordingly. */ - scheduling_state task_status = task_state(task); - if (task_status != TASK_STATUS_DONE) { - LOG_DEBUG("Task to reconstruct had scheduling state %d", task_status); +void reconstruct_task_update_callback(task *task, void *user_context) { + if (task == NULL) { + /* The test-and-set of the task's scheduling state failed, so the task was + * either not finished yet, or it was already being reconstructed. + * Suppress the reconstruction request. */ return; } - /* Recursively reconstruct the task's inputs, if necessary. */ + /* Otherwise, the test-and-set succeeded, so resubmit the task for execution + * to ensure that reconstruction will happen. */ + local_scheduler_state *state = user_context; task_spec *spec = task_task_spec(task); - for (int64_t i = 0; i < task_num_args(spec); ++i) { - object_id arg_id = task_arg_id(spec, i); - reconstruct_object(state, arg_id); - } handle_task_submitted(state, state->algorithm_state, spec); + + /* Recursively reconstruct the task's inputs, if necessary. */ + for (int64_t i = 0; i < task_num_args(spec); ++i) { + if (task_arg_type(spec, i) == ARG_BY_REF) { + object_id arg_id = task_arg_id(spec, i); + reconstruct_object(state, arg_id); + } + } } -void reconstruct_object_object_lookup_callback(object_id reconstruct_object_id, - int manager_count, - const char *manager_vector[], - void *user_context) { +void reconstruct_result_lookup_callback(object_id reconstruct_object_id, + task_id task_id, + void *user_context) { + /* TODO(swang): The following check will fail if an object was created by a + * put. */ + CHECKM(!IS_NIL_ID(task_id), + "No task information found for object during reconstruction"); + local_scheduler_state *state = user_context; + /* Try to claim the responsibility for reconstruction by doing a test-and-set + * of the task's scheduling state in the global state. 
If the task's + * scheduling state is pending completion, assume that reconstruction is + * already being taken care of. NOTE: This codepath is not responsible for + * detecting failure of the other reconstruction, or updating the + * scheduling_state accordingly. */ + task_table_test_and_update( + state->db, task_id, TASK_STATUS_DONE, TASK_STATUS_RECONSTRUCTING, + (retry_info *) &photon_retry, reconstruct_task_update_callback, state); +} + +void reconstruct_object_lookup_callback(object_id reconstruct_object_id, + int manager_count, + const char *manager_vector[], + void *user_context) { + LOG_DEBUG("Manager count was %d", manager_count); /* Only continue reconstruction if we find that the object doesn't exist on * any nodes. NOTE: This codepath is not responsible for checking if the * object table entry is up-to-date. */ local_scheduler_state *state = user_context; if (manager_count == 0) { /* Look up the task that created the object in the result table. */ - result_table_lookup( - state->db, reconstruct_object_id, (retry_info *) &photon_retry, - reconstruct_object_task_lookup_callback, (void *) state); + result_table_lookup(state->db, reconstruct_object_id, + (retry_info *) &photon_retry, + reconstruct_result_lookup_callback, (void *) state); } } @@ -226,9 +242,9 @@ void reconstruct_object(local_scheduler_state *state, CHECK(state->db != NULL); /* Determine if reconstruction is necessary by checking if the object exists * on a node. 
*/ - object_table_lookup( - state->db, reconstruct_object_id, (retry_info *) &photon_retry, - reconstruct_object_object_lookup_callback, (void *) state); + object_table_lookup(state->db, reconstruct_object_id, + (retry_info *) &photon_retry, + reconstruct_object_lookup_callback, (void *) state); } void process_message(event_loop *loop, diff --git a/src/photon/test/photon_tests.c b/src/photon/test/photon_tests.c index edea72140..a4983b143 100644 --- a/src/photon/test/photon_tests.c +++ b/src/photon/test/photon_tests.c @@ -101,6 +101,21 @@ TEST object_reconstruction_test(void) { photon_mock *photon = init_photon_mock(true); /* Create a task with zero dependencies and one return value. */ task_spec *spec = example_task_spec(0, 1); + object_id return_id = task_return(spec, 0); + + /* Add an empty object table entry for the object we want to reconstruct, to + * simulate it having been created and evicted. */ + const char *client_id = "clientid"; + redisContext *context = redisConnect("127.0.0.1", 6379); + redisReply *reply = redisCommand(context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", + return_id.id, sizeof(return_id.id), 1, + NIL_DIGEST, (size_t) DIGEST_SIZE, client_id); + freeReplyObject(reply); + reply = redisCommand(context, "RAY.OBJECT_TABLE_REMOVE %b %s", return_id.id, + sizeof(return_id.id), client_id); + freeReplyObject(reply); + redisFree(context); + pid_t pid = fork(); if (pid == 0) { /* Make sure we receive the task twice. First from the initial submission, @@ -163,6 +178,23 @@ TEST object_reconstruction_recursive_test(void) { photon->photon_state->algorithm_state, arg_id); specs[i] = example_task_spec_with_args(1, 1, &arg_id); } + + /* Add an empty object table entry for each object we want to reconstruct, to + * simulate their having been created and evicted. 
*/ + const char *client_id = "clientid"; + redisContext *context = redisConnect("127.0.0.1", 6379); + for (int i = 0; i < NUM_TASKS; ++i) { + object_id return_id = task_return(specs[i], 0); + redisReply *reply = redisCommand( + context, "RAY.OBJECT_TABLE_ADD %b %ld %b %s", return_id.id, + sizeof(return_id.id), 1, NIL_DIGEST, (size_t) DIGEST_SIZE, client_id); + freeReplyObject(reply); + reply = redisCommand(context, "RAY.OBJECT_TABLE_REMOVE %b %s", return_id.id, + sizeof(return_id.id), client_id); + freeReplyObject(reply); + } + redisFree(context); + pid_t pid = fork(); if (pid == 0) { /* Submit the tasks, and make sure each one gets assigned to a worker. */ @@ -209,7 +241,8 @@ TEST object_reconstruction_recursive_test(void) { get_db_client_id(photon->photon_state->db)); task_table_add_task(photon->photon_state->db, last_task, (retry_info *) &photon_retry, NULL, NULL); - /* Trigger reconstruction, and run the event loop again. */ + /* Trigger reconstruction for the last object, and run the event loop + * again. 
*/ object_id return_id = task_return(specs[NUM_TASKS - 1], 0); photon_reconstruct_object(photon->conn, return_id); event_loop_add_timer(photon->loop, 500, diff --git a/src/plasma/CMakeLists.txt b/src/plasma/CMakeLists.txt index 5760ee382..11fe1fc73 100644 --- a/src/plasma/CMakeLists.txt +++ b/src/plasma/CMakeLists.txt @@ -37,6 +37,7 @@ include_directories("${CMAKE_CURRENT_LIST_DIR}/../") add_library(plasma SHARED plasma.c plasma_extension.c + ../common/lib/python/common_extension.c plasma_protocol.c plasma_client.c thirdparty/xxhash.c diff --git a/src/plasma/eviction_policy.c b/src/plasma/eviction_policy.c index ec9af2c1f..477ed68d4 100644 --- a/src/plasma/eviction_policy.c +++ b/src/plasma/eviction_policy.c @@ -172,7 +172,6 @@ bool require_space(eviction_state *eviction_state, num_bytes_evicted = choose_objects_to_evict( eviction_state, plasma_store_info, space_to_free, num_objects_to_evict, objects_to_evict); - printf("Evicted %" PRId64 " bytes.\n", num_bytes_evicted); LOG_INFO( "There is not enough space to create this object, so evicting " "%" PRId64 " objects to free up %" PRId64 " bytes.\n", diff --git a/src/plasma/plasma_extension.c b/src/plasma/plasma_extension.c index b8960d14a..63b95049a 100644 --- a/src/plasma/plasma_extension.c +++ b/src/plasma/plasma_extension.c @@ -1,6 +1,7 @@ #include #include "bytesobject.h" +#include "common_extension.h" #include "common.h" #include "io.h" #include "plasma_protocol.h" diff --git a/src/plasma/plasma_extension.h b/src/plasma/plasma_extension.h index a5adc9785..0abf746f4 100644 --- a/src/plasma/plasma_extension.h +++ b/src/plasma/plasma_extension.h @@ -12,14 +12,4 @@ static int PyObjectToPlasmaConnection(PyObject *object, } } -static int PyStringToUniqueID(PyObject *object, object_id *object_id) { - if (PyBytes_Check(object)) { - memcpy(&object_id->id[0], PyBytes_AsString(object), UNIQUE_ID_SIZE); - return 1; - } else { - PyErr_SetString(PyExc_TypeError, "must be a 20 character string"); - return 0; - } -} - #endif /* 
PLASMA_EXTENSION_H */ diff --git a/test/stress_tests.py b/test/stress_tests.py index ca242ee99..9bf8ea846 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -6,6 +6,7 @@ import unittest import ray import numpy as np import time +import redis class TaskTests(unittest.TestCase): @@ -108,5 +109,177 @@ class TaskTests(unittest.TestCase): self.assertTrue(ray.services.all_processes_alive()) ray.worker.cleanup() +class ReconstructionTests(unittest.TestCase): + + num_local_schedulers = 1 + + def setUp(self): + # Start a Redis instance and Plasma store instances with a total of 1GB + # memory. + node_ip_address = "127.0.0.1" + self.redis_port = ray.services.new_port() + print(self.redis_port) + redis_address = ray.services.address(node_ip_address, self.redis_port) + self.plasma_store_memory = 10 ** 9 + plasma_addresses = [] + objstore_memory = (self.plasma_store_memory // self.num_local_schedulers) + for i in range(self.num_local_schedulers): + plasma_addresses.append( + ray.services.start_objstore(node_ip_address, redis_address, + objstore_memory=objstore_memory) + ) + address_info = { + "redis_address": redis_address, + "object_store_addresses": plasma_addresses, + } + + # Start the rest of the services in the Ray cluster. + ray.worker._init(address_info=address_info, start_ray_local=True, + num_workers=self.num_local_schedulers, num_local_schedulers=self.num_local_schedulers) + + def tearDown(self): + self.assertTrue(ray.services.all_processes_alive()) + + # Make sure that all nodes in the cluster were used by checking where tasks + # were scheduled and/or submitted from. + r = redis.StrictRedis(port=self.redis_port) + task_ids = r.keys("TT:*") + task_ids = [task_id[3:] for task_id in task_ids] + node_ids = [r.execute_command("ray.task_table_get", task_id)[1] for task_id + in task_ids] + self.assertEqual(len(set(node_ids)), self.num_local_schedulers) + + # Clean up the Ray cluster. 
+ ray.worker.cleanup() + + def testSimple(self): + # Define the size of one task's return argument so that the combined sum of + # all objects' sizes is at least twice the plasma stores' combined allotted + # memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) + + # Define a remote task with no dependencies, which returns a numpy array of + # the given size. + @ray.remote + def foo(i, size): + array = np.zeros(size) + array[0] = i + return array + + # Launch num_objects instances of the remote task. + args = [] + for i in range(num_objects): + args.append(foo.remote(i, size)) + + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + + def testRecursive(self): + # Define the size of one task's return argument so that the combined sum of + # all objects' sizes is at least twice the plasma stores' combined allotted + # memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) + + # Define a root task with no dependencies, which returns a numpy array of + # the given size. + @ray.remote + def no_dependency_task(size): + array = np.zeros(size) + return array + + # Define a task with a single dependency, which returns its one argument. + @ray.remote + def single_dependency(i, arg): + arg = np.copy(arg) + arg[0] = i + return arg + + # Launch num_objects instances of the remote task, each dependent on the + # one before it. + arg = no_dependency_task.remote(size) + args = [] + for i in range(num_objects): + arg = single_dependency.remote(i, arg) + args.append(arg) + + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. 
+ for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get 10 values randomly. + for _ in range(10): + i = np.random.randint(num_objects) + value = ray.get(args[i]) + self.assertEqual(value[0], i) + + def testMultipleRecursive(self): + # Define the size of one task's return argument so that the combined sum of + # all objects' sizes is at least twice the plasma stores' combined allotted + # memory. + num_objects = 1000 + size = self.plasma_store_memory * 2 // (num_objects * 8) + + # Define a root task with no dependencies, which returns a numpy array of + # the given size. + @ray.remote + def no_dependency_task(size): + array = np.zeros(size) + return array + + # Define a task with multiple dependencies, which returns its first + # argument. + @ray.remote + def multiple_dependency(i, arg1, arg2, arg3): + arg1 = np.copy(arg1) + arg1[0] = i + return arg1 + + # Launch num_args instances of the root task. Then launch num_objects + # instances of the multi-dependency remote task, each dependent on the + # num_args tasks before it. + num_args = 3 + args = [] + for i in range(num_args): + arg = no_dependency_task.remote(size) + args.append(arg) + for i in range(num_objects): + args.append(multiple_dependency.remote(i, *args[i:i + num_args])) + + # Get each value to force each task to finish. After some number of gets, + # old values should be evicted. + args = args[num_args:] + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get each value again to force reconstruction. + for i in range(num_objects): + value = ray.get(args[i]) + self.assertEqual(value[0], i) + # Get 10 values randomly. 
+ for _ in range(10): + i = np.random.randint(num_objects) + value = ray.get(args[i]) + self.assertEqual(value[0], i) + +class ReconstructionTestsMultinode(ReconstructionTests): + + # Run the same tests as the single-node suite, but with 4 local schedulers, + # one worker each. + num_local_schedulers = 4 + if __name__ == "__main__": unittest.main(verbosity=2)