diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index cdf6f651d..3f9bbc731 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -90,7 +90,8 @@ class TestGlobalScheduler(unittest.TestCase): redis_address=redis_address, static_resource_list=[10, 0]) # Connect to the scheduler. - local_scheduler_client = local_scheduler.LocalSchedulerClient(local_scheduler_name, NIL_ACTOR_ID) + local_scheduler_client = local_scheduler.LocalSchedulerClient( + local_scheduler_name, NIL_ACTOR_ID, False) self.local_scheduler_clients.append(local_scheduler_client) self.local_scheduler_pids.append(p4) diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py index 2c3f09bb0..a89e84181 100644 --- a/python/ray/local_scheduler/test/test.py +++ b/python/ray/local_scheduler/test/test.py @@ -41,7 +41,8 @@ class TestLocalSchedulerClient(unittest.TestCase): # Start a local scheduler. scheduler_name, self.p2 = local_scheduler.start_local_scheduler(plasma_store_name, use_valgrind=USE_VALGRIND) # Connect to the scheduler. - self.local_scheduler_client = local_scheduler.LocalSchedulerClient(scheduler_name, NIL_ACTOR_ID) + self.local_scheduler_client = local_scheduler.LocalSchedulerClient( + scheduler_name, NIL_ACTOR_ID, False) def tearDown(self): # Check that the processes are still alive. diff --git a/python/ray/monitor.py b/python/ray/monitor.py index a563cd1ea..3d524ecd9 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -18,21 +18,28 @@ from ray.core.generated.TaskReply import TaskReply # These variables must be kept in sync with the C codebase. 
# common/common.h +HEARTBEAT_TIMEOUT_MILLISECONDS = 100 +NUM_HEARTBEATS_TIMEOUT = 100 DB_CLIENT_ID_SIZE = 20 NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE # common/task.h TASK_STATUS_LOST = 32 +# common/state/redis.cc +PLASMA_MANAGER_HEARTBEAT_CHANNEL = b"plasma_managers" # common/redis_module/ray_redis_module.cc TASK_PREFIX = "TT:" +OBJECT_PREFIX = "OL:" DB_CLIENT_PREFIX = "CL:" DB_CLIENT_TABLE_NAME = b"db_clients" # local_scheduler/local_scheduler.h -LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS = 100 LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler" +# plasma/plasma_manager.cc +PLASMA_MANAGER_CLIENT_TYPE = b"plasma_manager" # Set up logging. logging.basicConfig() log = logging.getLogger() +log.setLevel(logging.WARN) class Monitor(object): """A monitor for Ray processes. @@ -45,69 +52,45 @@ class Monitor(object): redis: A connection to the Redis server. subscribe_client: A pubsub client for the Redis server. This is used to receive notifications about failed components. - local_schedulers: A set of the local scheduler IDs of all of the currently - live local schedulers in the cluster. In addition, this also includes - NIL_ID. + subscribed: A dictionary mapping channel names (str) to whether or not the + subscription to that channel has succeeded yet (bool). + dead_local_schedulers: A set of the local scheduler IDs of all of the local + schedulers that were up at one point and have died since then. + live_plasma_managers: A counter mapping live plasma manager IDs to the + number of heartbeats that have passed since we last heard from that + plasma manager. A plasma manager is live if we received a heartbeat from + it at any point, and if it has not timed out. + dead_plasma_managers: A set of the plasma manager IDs of all the plasma + managers that were up at one point and have died since then. """ def __init__(self, redis_address, redis_port): + # Initialize the Redis clients. 
self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0) self.subscribe_client = self.redis.pubsub() - + self.subscribed = {} # Initialize data structures to keep track of the active database clients. - self.local_schedulers = set() - # Add the NIL_ID so that we don't accidentally mark tasks that aren't - # associated with a node as LOST during cleanup. - self.local_schedulers.add(NIL_ID) + self.dead_local_schedulers = set() + self.live_plasma_managers = Counter() + self.dead_plasma_managers = set() - def subscribe(self): - """Subscribe to the db_clients channel. + def subscribe(self, channel): + """Subscribe to the given channel. + + Args: + channel (str): The channel to subscribe to. Raises: Exception: An exception is raised if the subscription fails. """ - self.subscribe_client.subscribe(DB_CLIENT_TABLE_NAME) - # Wait for the first message to signal that the subscription was successful. - while True: - message = self.subscribe_client.get_message() - if message is None: - time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000) - continue - break - - # The first message's payload should be the index of our subscription. - if "data" not in message: - Exception("Unable to subscribe to local scheduler table.") - - def read_message(self): - """Read a message from the db_clients channel. - - Returns: - None if no message was to read. Otherwise, a tuple of (db_client_id, - client_type, auxiliary_address, is_insertion) is returned. The value - is_insertion is a bool that is true if the update to the db_clients - table was an insertion and false if deletion. - """ - message = self.subscribe_client.get_message() - if message is None: - return None - - # Parse the message. 
- data = message["data"] - - notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0) - db_client_id = notification_object.DbClientId() - client_type = notification_object.ClientType() - auxiliary_address = notification_object.AuxAddress() - is_insertion = notification_object.IsInsertion() - - return db_client_id, client_type, auxiliary_address, is_insertion + self.subscribe_client.subscribe(channel) + self.subscribed[channel] = False def cleanup_task_table(self): - """Clean up global state for a failed local schedulers. + """Clean up global state for failed local schedulers. This marks any tasks that were scheduled on dead local schedulers as - TASK_STATUS_LOST. A local scheduler is deemed dead if it is not in - self.local_schedulers. + TASK_STATUS_LOST. A local scheduler is deemed dead if it is in + self.dead_local_schedulers. """ task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX)) num_tasks_updated = 0 @@ -118,29 +101,146 @@ class Monitor(object): task_object = TaskReply.GetRootAsTaskReply(response, 0) local_scheduler_id = task_object.LocalSchedulerId() # See if the corresponding local scheduler is alive. - if local_scheduler_id not in self.local_schedulers: - num_tasks_updated += 1 + if local_scheduler_id in self.dead_local_schedulers: + # If the task is scheduled on a dead local scheduler, mark the task as + # lost. ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE", task_id, TASK_STATUS_LOST, NIL_ID) if ok != b"OK": log.warn("Failed to update lost task for dead scheduler.") + num_tasks_updated += 1 if num_tasks_updated > 0: log.warn("Marked {} tasks as lost.".format(num_tasks_updated)) + def cleanup_object_table(self): + """Clean up global state for failed plasma managers. + + This removes dead plasma managers from any location entries in the object + table. A plasma manager is deemed dead if it is in + self.dead_plasma_managers. 
+ """ + # TODO(swang): Also kill the associated plasma store, since it's no longer + # reachable without a plasma manager. + object_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=OBJECT_PREFIX)) + num_objects_removed = 0 + for object_id in object_ids: + object_id = object_id[len(OBJECT_PREFIX):] + managers = self.redis.execute_command("RAY.OBJECT_TABLE_LOOKUP", object_id) + for manager in managers: + if manager in self.dead_plasma_managers: + # If the object was on a dead plasma manager, remove that location + # entry. + ok = self.redis.execute_command("RAY.OBJECT_TABLE_REMOVE", object_id, + manager) + if ok != b"OK": + log.warn("Failed to remove object location for dead plasma " + "manager.") + num_objects_removed += 1 + if num_objects_removed > 0: + log.warn("Marked {} objects as lost.".format(num_objects_removed)) + def scan_db_client_table(self): - """Scan the database client table for the current clients. + """Scan the database client table for dead clients. After subscribing to the client table, it's necessary to call this before - reading any messages from the subscription channel. + reading any messages from the subscription channel. This ensures that we do + not miss any notifications for deleted clients that occurred before we + subscribed. 
""" db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX)) for db_client_key in db_client_keys: db_client_id = db_client_key[len(DB_CLIENT_PREFIX):] - client_type = self.redis.hget(db_client_key, "client_type") - if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - self.local_schedulers.add(db_client_id) + client_type, deleted = self.redis.hmget(db_client_key, + [b"client_type", b"deleted"]) + deleted = bool(int(deleted)) + if deleted: + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + self.dead_local_schedulers.add(db_client_id) + elif client_type == PLASMA_MANAGER_CLIENT_TYPE: + self.dead_plasma_managers.add(db_client_id) + + def subscribe_handler(self, channel, data): + """Handle a subscription success message from Redis. + """ + log.debug("Subscribed to {}, data was {}".format(channel, data)) + self.subscribed[channel] = True + + def db_client_notification_handler(self, channel, data): + """Handle a notification from the db_client table from Redis. + + This handler processes notifications from the db_client table. + Notifications should be parsed using the SubscribeToDBClientTableReply + flatbuffer. Deletions are processed, insertions are ignored. Cleanup of the + associated state in the state tables should be handled by the caller. + """ + notification_object = SubscribeToDBClientTableReply.GetRootAsSubscribeToDBClientTableReply(data, 0) + db_client_id = notification_object.DbClientId() + client_type = notification_object.ClientType() + auxiliary_address = notification_object.AuxAddress() + is_insertion = notification_object.IsInsertion() + + # If the update was an insertion, we ignore it. + if is_insertion: + return + + # If the update was a deletion, add them to our accounting for dead + # local schedulers and plasma managers. 
+ log.warn("Removed {}".format(client_type)) + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + if db_client_id not in self.dead_local_schedulers: + self.dead_local_schedulers.add(db_client_id) + elif client_type == PLASMA_MANAGER_CLIENT_TYPE: + if db_client_id not in self.dead_plasma_managers: + self.dead_plasma_managers.add(db_client_id) + + def plasma_manager_heartbeat_handler(self, channel, data): + """Handle a plasma manager heartbeat from Redis. + + This resets the number of heartbeats that we've missed from this plasma + manager. + """ + # The first DB_CLIENT_ID_SIZE characters are the client ID. + db_client_id = data[:DB_CLIENT_ID_SIZE] + # Reset the number of heartbeats that we've missed from this plasma + # manager. + self.live_plasma_managers[db_client_id] = 0 + + def process_messages(self): + """Process all messages ready in the subscription channels. + + This reads messages from the subscription channels and calls the + appropriate handlers until there are no messages left. + """ + while True: + message = self.subscribe_client.get_message() + if message is None: + return + + # Parse the message. + channel = message["channel"] + data = message["data"] + + # Determine the appropriate message handler. + message_handler = None + if not self.subscribed[channel]: + # If the data was an integer, then the message was a response to an + # initial subscription request. + is_subscribe = int(data) + message_handler = self.subscribe_handler + elif channel == PLASMA_MANAGER_HEARTBEAT_CHANNEL: + assert(self.subscribed[channel]) + # The message was a heartbeat from a plasma manager. + message_handler = self.plasma_manager_heartbeat_handler + elif channel == DB_CLIENT_TABLE_NAME: + assert(self.subscribed[channel]) + # The message was a notification from the db_client table. + message_handler = self.db_client_notification_handler + + # Call the handler. + assert(message_handler is not None) + message_handler(channel, data) def run(self): """Run the monitor. 
@@ -149,39 +249,64 @@ class Monitor(object): clients and cleaning up state accordingly. """ # Initialize the subscription channel. - self.subscribe() + self.subscribe(DB_CLIENT_TABLE_NAME) + self.subscribe(PLASMA_MANAGER_HEARTBEAT_CHANNEL) - # Scan the database table and clean up any state associated with clients - # not in the database table. NOTE: This must be called before reading any - # messages from the subscription channel. This ensures that we start in a - # consistent state, since we may have missed notifications that were sent - # before we connected to the subscription channel. + # Scan the database table for dead database clients. NOTE: This must be + # called before reading any messages from the subscription channel. This + # ensures that we start in a consistent state, since we may have missed + # notifications that were sent before we connected to the subscription + # channel. self.scan_db_client_table() - self.cleanup_task_table() - log.debug("Scanned schedulers: {}".format(self.local_schedulers)) + # If there were any dead clients at startup, clean up the associated state + # in the state tables. + if len(self.dead_local_schedulers) > 0: + self.cleanup_task_table() + if len(self.dead_plasma_managers) > 0: + self.cleanup_object_table() + log.debug("{} dead local schedulers, {} plasma " + "managers total, {} dead plasma managers".format( + len(self.dead_local_schedulers), + len(self.live_plasma_managers) + len(self.dead_plasma_managers), + len(self.dead_plasma_managers) + )) - # Read messages from the subscription channel. + # Handle messages from the subscription channels. while True: - time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000) - client = self.read_message() - # There was no message to be read. - if client is None: - continue + # Record how many dead local schedulers and plasma managers we had at the + # beginning of this round. 
+ num_dead_local_schedulers = len(self.dead_local_schedulers) + num_dead_plasma_managers = len(self.dead_plasma_managers) + # Process a round of messages. + self.process_messages() + # If any new local schedulers or plasma managers were marked as dead in + # this round, clean up the associated state. + if len(self.dead_local_schedulers) > num_dead_local_schedulers: + self.cleanup_task_table() + if len(self.dead_plasma_managers) > num_dead_plasma_managers: + self.cleanup_object_table() - db_client_id, client_type, auxiliary_address, is_insertion = client + # Handle plasma managers that timed out during this round. + plasma_manager_ids = list(self.live_plasma_managers.keys()) + for plasma_manager_id in plasma_manager_ids: + if self.live_plasma_managers[plasma_manager_id] >= NUM_HEARTBEATS_TIMEOUT: + log.warn("Timed out {}".format(PLASMA_MANAGER_CLIENT_TYPE)) + # Remove the plasma manager from the managers whose heartbeats we're + # tracking. + del self.live_plasma_managers[plasma_manager_id] + # Remove the plasma manager from the db_client table. The + # corresponding state in the object table will be cleaned up once we + # receive the notification for this db_client deletion. + self.redis.execute_command("RAY.DISCONNECT", plasma_manager_id) - # If the update was an insertion, record the client ID. - if is_insertion: - self.local_schedulers.add(db_client_id) - log.debug("Added scheduler: {}".format(db_client_id)) - continue + # Increment the number of heartbeats that we've missed from each plasma manager. + for plasma_manager_id in self.live_plasma_managers: + self.live_plasma_managers[plasma_manager_id] += 1 + + # Wait for a heartbeat interval before processing the next round of + # messages. + time.sleep(HEARTBEAT_TIMEOUT_MILLISECONDS * 1e-3) - # If the update was a deletion, clean up global state. 
- if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: - if db_client_id in self.local_schedulers: - log.warn("Removed scheduler: {}".format(db_client_id)) - self.local_schedulers.remove(db_client_id) - self.cleanup_task_table() if __name__ == "__main__": parser = argparse.ArgumentParser(description=("Parse Redis server for the " diff --git a/python/ray/plasma/test/test.py b/python/ray/plasma/test/test.py index 79a554df5..a7c833057 100644 --- a/python/ray/plasma/test/test.py +++ b/python/ray/plasma/test/test.py @@ -699,11 +699,21 @@ class TestPlasmaManager(unittest.TestCase): self.assertEqual(set(waiting), set(object_ids_perm[(i + 1):])) def test_transfer(self): + num_attempts = 100 for _ in range(100): # Create an object. object_id1, memory_buffer1, metadata1 = create_object(self.client1, 2000, 2000) - # Transfer the buffer to the the other PlasmaStore. - self.client1.transfer("127.0.0.1", self.port2, object_id1) + # Transfer the buffer to the other Plasma store. There is a race + # condition on the create and transfer of the object, so keep trying + # until the object appears on the second Plasma store. + for i in range(num_attempts): + self.client1.transfer("127.0.0.1", self.port2, object_id1) + buff = self.client2.get([object_id1], timeout_ms=100)[0] + if buff is not None: + break + self.assertNotEqual(buff, None) + del buff + # Compare the two buffers. assert_get_object_equal(self, self.client1, self.client2, object_id1, memory_buffer=memory_buffer1, metadata=metadata1) @@ -715,8 +725,17 @@ class TestPlasmaManager(unittest.TestCase): # Create an object. object_id2, memory_buffer2, metadata2 = create_object(self.client2, 20000, 20000) - # Transfer the buffer to the the other PlasmaStore. - self.client2.transfer("127.0.0.1", self.port1, object_id2) + # Transfer the buffer to the other Plasma store. There is a race + # condition on the create and transfer of the object, so keep trying + # until the object appears on the second Plasma store. 
+ for i in range(num_attempts): + self.client2.transfer("127.0.0.1", self.port1, object_id2) + buff = self.client1.get([object_id2], timeout_ms=100)[0] + if buff is not None: + break + self.assertNotEqual(buff, None) + del buff + # Compare the two buffers. assert_get_object_equal(self, self.client1, self.client2, object_id2, memory_buffer=memory_buffer2, metadata=metadata2) @@ -761,7 +780,9 @@ class TestPlasmaManagerRecovery(unittest.TestCase): # Store the processes that will be explicitly killed during tearDown so # that a test case can remove ones that will be killed during the test. - self.processes_to_kill = [self.p2, self.p3] + # NOTE: The plasma managers must be killed before the plasma store since + # plasma store death will bring down the managers. + self.processes_to_kill = [self.p3, self.p2] def tearDown(self): # Check that the processes are still alive. @@ -798,7 +819,7 @@ class TestPlasmaManagerRecovery(unittest.TestCase): # Start a second plasma manager attached to the same store. manager_name, self.p5, self.port2 = plasma.start_plasma_manager(self.store_name, self.redis_address, use_valgrind=USE_VALGRIND) - self.processes_to_kill.append(self.p5) + self.processes_to_kill = [self.p5] + self.processes_to_kill # Check that the second manager knows about existing objects. client2 = plasma.PlasmaClient(self.store_name, manager_name) diff --git a/python/ray/worker.py b/python/ray/worker.py index 82c55f4b8..5700b23a3 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1238,10 +1238,6 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a redis_ip_address, redis_port = info["redis_address"].split(":") worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) worker.lock = threading.Lock() - # Create an object store client. - worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"]) - # Create the local scheduler client. 
- worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(info["local_scheduler_socket_name"], worker.actor_id) # Register the worker with Redis. if mode in [SCRIPT_MODE, SILENT_MODE]: # The concept of a driver is the same as the concept of a "job". Register @@ -1255,6 +1251,7 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a "local_scheduler_socket": info["local_scheduler_socket_name"]} driver_info["name"] = main.__file__ if hasattr(main, "__file__") else "INTERACTIVE MODE" worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info) + is_worker = False elif mode == WORKER_MODE: # Register the worker with Redis. worker.redis_client.hmset(b"Workers:" + worker.worker_id, @@ -1262,8 +1259,15 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, a "plasma_store_socket": info["store_socket_name"], "plasma_manager_socket": info["manager_socket_name"], "local_scheduler_socket": info["local_scheduler_socket_name"]}) + is_worker = True else: raise Exception("This code should be unreachable.") + + # Create an object store client. + worker.plasma_client = ray.plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"]) + # Create the local scheduler client. + worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(info["local_scheduler_socket_name"], worker.actor_id, is_worker) + # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. if mode in [SCRIPT_MODE, SILENT_MODE]: diff --git a/scripts/start_ray.py b/scripts/start_ray.py index d49382aa4..feb61a567 100644 --- a/scripts/start_ray.py +++ b/scripts/start_ray.py @@ -30,6 +30,14 @@ def check_no_existing_redis_clients(node_ip_address, redis_address): assert b"ray_client_id" in info assert b"node_ip_address" in info assert b"client_type" in info + assert b"deleted" in info + # Clients that ran on the same node but that are marked dead can be + # ignored. 
+ deleted = info[b"deleted"] + deleted = bool(int(deleted)) + if deleted: + continue + if info[b"node_ip_address"].decode("ascii") == node_ip_address: raise Exception("This Redis instance is already connected to clients with this IP address.") diff --git a/src/common/common.h b/src/common/common.h index 7b04ea19b..f5aa17b7c 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -20,6 +20,14 @@ extern "C" { } #endif +/** The duration between heartbeats. These are sent by the plasma manager and + * local scheduler. */ +#define HEARTBEAT_TIMEOUT_MILLISECONDS 100 +/** If a component has not sent a heartbeat in the last NUM_HEARTBEATS_TIMEOUT + * heartbeat intervals, the global scheduler or monitor process will report it + * as dead to the db_client table. */ +#define NUM_HEARTBEATS_TIMEOUT 100 + /** Definitions for Ray logging levels. */ #define RAY_COMMON_DEBUG 0 #define RAY_COMMON_INFO 1 diff --git a/src/common/io.cc b/src/common/io.cc index 0c9dd8b03..c4d2679ee 100644 --- a/src/common/io.cc +++ b/src/common/io.cc @@ -193,6 +193,7 @@ int connect_inet_sock(const char *ip_addr, int port) { struct hostent *manager = gethostbyname(ip_addr); /* TODO(pcm): cache this */ if (!manager) { LOG_ERROR("Failed to get hostname from address %s:%d.", ip_addr, port); + close(fd); return -1; } @@ -203,6 +204,7 @@ int connect_inet_sock(const char *ip_addr, int port) { if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) != 0) { LOG_ERROR("Connection to socket failed for address %s:%d.", ip_addr, port); + close(fd); return -1; } return fd; diff --git a/src/common/redis_module/ray_redis_module.cc b/src/common/redis_module/ray_redis_module.cc index f83b2c64b..c9fb2194f 100644 --- a/src/common/redis_module/ray_redis_module.cc +++ b/src/common/redis_module/ray_redis_module.cc @@ -73,6 +73,12 @@ flatbuffers::Offset RedisStringToFlatbuf( * Publish a notification to a client's notification channel about an insertion * or deletion to the db client table. 
* + * TODO(swang): Use flatbuffers for the notification message. + * The format for the published notification is: + * <ray_client_id>:<client_type> <aux_address> <is_insertion> + * If no auxiliary address is provided, aux_address will be set to ":". If + * is_insertion is true, then the last field will be "1", else "0". + * * @param ctx The Redis context. * @param ray_client_id The ID of the database client that was inserted or * deleted. @@ -159,14 +165,20 @@ int Connect_RedisCommand(RedisModuleCtx *ctx, RedisModuleKey *db_client_table_key = OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE); + if (RedisModule_KeyType(db_client_table_key) != REDISMODULE_KEYTYPE_EMPTY) { + return RedisModule_ReplyWithError(ctx, "Client already exists"); + } + /* This will be used to construct a publish message. */ RedisModuleString *aux_address = NULL; RedisModuleString *aux_address_key = RedisModule_CreateString(ctx, "aux_address", strlen("aux_address")); + RedisModuleString *deleted = RedisModule_CreateString(ctx, "0", strlen("0")); RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS, "ray_client_id", ray_client_id, "node_ip_address", - node_ip_address, "client_type", client_type, NULL); + node_ip_address, "client_type", client_type, "deleted", + deleted, NULL); for (int i = 4; i < argc; i += 2) { RedisModuleString *key = argv[i]; @@ -178,6 +190,7 @@ int Connect_RedisCommand(RedisModuleCtx *ctx, } } /* Clean up. */ + RedisModule_FreeString(ctx, deleted); RedisModule_FreeString(ctx, aux_address_key); RedisModule_CloseKey(db_client_table_key); if (!PublishDBClientNotification(ctx, ray_client_id, client_type, aux_address, @@ -213,32 +226,47 @@ int Disconnect_RedisCommand(RedisModuleCtx *ctx, /* Get the client type. */ RedisModuleKey *db_client_table_key = OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE); - if (RedisModule_KeyType(db_client_table_key) == REDISMODULE_KEYTYPE_EMPTY) { - /* Someone else already deleted this client. 
*/ + + RedisModuleString *deleted_string; + RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, "deleted", + &deleted_string, NULL); + long long deleted; + int parsed = RedisModule_StringToLongLong(deleted_string, &deleted); + RedisModule_FreeString(ctx, deleted_string); + if (parsed != REDISMODULE_OK) { RedisModule_CloseKey(db_client_table_key); - RedisModule_ReplyWithSimpleString(ctx, "OK"); - return REDISMODULE_OK; + return RedisModule_ReplyWithError(ctx, "Unable to parse deleted field"); } - RedisModuleString *client_type; - RedisModuleString *aux_address; - RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, - "client_type", &client_type, "aux_address", &aux_address, - NULL); + bool published = true; + if (deleted == 0) { + /* Remove the client from the client table. */ + RedisModuleString *deleted = + RedisModule_CreateString(ctx, "1", strlen("1")); + RedisModule_HashSet(db_client_table_key, REDISMODULE_HASH_CFIELDS, + "deleted", deleted, NULL); + RedisModule_FreeString(ctx, deleted); + + RedisModuleString *client_type; + RedisModuleString *aux_address; + RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS, + "client_type", &client_type, "aux_address", + &aux_address, NULL); + + /* Publish the deletion notification on the db client channel. */ + published = PublishDBClientNotification(ctx, ray_client_id, client_type, + aux_address, false); + if (aux_address != NULL) { + RedisModule_FreeString(ctx, aux_address); + } + RedisModule_FreeString(ctx, client_type); + } - /* Remove the client from the client table. */ - CHECK_ERROR(RedisModule_DeleteKey(db_client_table_key), - "Unable to delete db client key."); RedisModule_CloseKey(db_client_table_key); - /* Publish the deletion notification on the db client channel. 
*/ - bool published = PublishDBClientNotification(ctx, ray_client_id, client_type, - aux_address, false); - - RedisModule_FreeString(ctx, aux_address); - RedisModule_FreeString(ctx, client_type); - if (!published) { + /* Return an error message if we weren't able to publish the deletion + * notification. */ return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); } diff --git a/src/common/state/db_client_table.cc b/src/common/state/db_client_table.cc index dbf75cba8..50837ac47 100644 --- a/src/common/state/db_client_table.cc +++ b/src/common/state/db_client_table.cc @@ -27,3 +27,14 @@ void db_client_table_subscribe( (table_done_callback) done_callback, redis_db_client_table_subscribe, user_context); } + +void plasma_manager_send_heartbeat(DBHandle *db_handle) { + RetryInfo heartbeat_retry; + heartbeat_retry.num_retries = 0; + heartbeat_retry.timeout = HEARTBEAT_TIMEOUT_MILLISECONDS; + heartbeat_retry.fail_callback = NULL; + + init_table_callback(db_handle, NIL_ID, __func__, NULL, + (RetryInfo *) &heartbeat_retry, NULL, + redis_plasma_manager_send_heartbeat, NULL); +} diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h index c5fa4cf50..c969104b2 100644 --- a/src/common/state/db_client_table.h +++ b/src/common/state/db_client_table.h @@ -65,4 +65,20 @@ typedef struct { void *subscribe_context; } DBClientTableSubscribeData; +/* + * ==== Plasma manager heartbeats ==== + */ + +/** + * Start sending heartbeats to the plasma_managers channel. Each + * heartbeat contains this database client's ID. Heartbeats can be subscribed + * to through the plasma_managers channel. Once called, this "retries" the + * heartbeat operation forever, every HEARTBEAT_TIMEOUT_MILLISECONDS + * milliseconds. + * + * @param db_handle Database handle. + * @return Void. 
+ */ +void plasma_manager_send_heartbeat(DBHandle *db_handle); + #endif /* DB_CLIENT_TABLE_H */ diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index 23d2b90bc..01a2e1264 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -1068,6 +1068,23 @@ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data) { } } +void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data) { + DBHandle *db = callback_data->db_handle; + /* NOTE(swang): We purposefully do not provide a callback, leaving the table + * operation and timer active. This allows us to send a new heartbeat every + * HEARTBEAT_TIMEOUT_MILLISECONDS without having to allocate and deallocate + * memory for callback data each time. */ + int status = redisAsyncCommand( + db->context, NULL, (void *) callback_data->timer_id, + "PUBLISH plasma_managers %b", db->client.id, sizeof(db->client.id)); + if ((status == REDIS_ERR) || db->context->err) { + LOG_REDIS_DEBUG(db->context, + "error in redis_plasma_manager_send_heartbeat"); + } + /* Clean up the timer and callback. */ + destroy_timer_callback(db->loop, callback_data); +} + void redis_actor_notification_table_subscribe_callback(redisAsyncContext *c, void *r, void *privdata) { diff --git a/src/common/state/redis.h b/src/common/state/redis.h index 1490f5b33..e8af9c643 100644 --- a/src/common/state/redis.h +++ b/src/common/state/redis.h @@ -253,6 +253,8 @@ void redis_local_scheduler_table_subscribe(TableCallbackData *callback_data); */ void redis_local_scheduler_table_send_info(TableCallbackData *callback_data); +void redis_plasma_manager_send_heartbeat(TableCallbackData *callback_data); + /** * Subscribe to updates about newly created actors. 
* diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index 457bea28a..6df8ea73a 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -386,6 +386,11 @@ int task_cleanup_handler(event_loop *loop, timer_id id, void *context) { } } + return GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS; +} + +int heartbeat_timeout_handler(event_loop *loop, timer_id id, void *context) { + GlobalSchedulerState *state = (GlobalSchedulerState *) context; /* Check for local schedulers that have missed a number of heartbeats. If any * local schedulers have died, notify others so that the state can be cleaned * up. */ @@ -395,8 +400,7 @@ int task_cleanup_handler(event_loop *loop, timer_id id, void *context) { for (int i = utarray_len(state->local_schedulers) - 1; i >= 0; --i) { local_scheduler_ptr = (LocalScheduler *) utarray_eltptr(state->local_schedulers, i); - if (local_scheduler_ptr->num_heartbeats_missed >= - GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT) { + if (local_scheduler_ptr->num_heartbeats_missed >= NUM_HEARTBEATS_TIMEOUT) { LOG_WARN( "Missed too many heartbeats from local scheduler, marking as dead."); /* Notify others by updating the global state. */ @@ -409,7 +413,7 @@ int task_cleanup_handler(event_loop *loop, timer_id id, void *context) { } /* Reset the timer. */ - return GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS; + return HEARTBEAT_TIMEOUT_MILLISECONDS; } void start_server(const char *redis_addr, int redis_port) { @@ -442,6 +446,8 @@ void start_server(const char *redis_addr, int redis_port) { * timer should notice and schedule the task. */ event_loop_add_timer(loop, GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS, task_cleanup_handler, g_state); + event_loop_add_timer(loop, HEARTBEAT_TIMEOUT_MILLISECONDS, + heartbeat_timeout_handler, g_state); /* Start the event loop. 
*/ event_loop_run(loop); } diff --git a/src/global_scheduler/global_scheduler.h b/src/global_scheduler/global_scheduler.h index 6f6e691ea..80a70e94a 100644 --- a/src/global_scheduler/global_scheduler.h +++ b/src/global_scheduler/global_scheduler.h @@ -11,10 +11,6 @@ /* The frequency with which the global scheduler checks if there are any tasks * that haven't been scheduled yet. */ #define GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS 100 -/* If a local scheduler has not sent a heartbeat in the last - * GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT heartbeat intervals, we will report it - * dead to the db_client table. */ -#define GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT 100 /** Contains all information that is associated with a local scheduler. */ typedef struct { diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index 7ede05578..1f1011558 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -146,6 +146,11 @@ void kill_worker(LocalSchedulerClient *worker, bool cleanup) { } void LocalSchedulerState_free(LocalSchedulerState *state) { + /* Reset the SIGTERM handler to default behavior, so we try to clean up the + * local scheduler at most once. If a SIGTERM is caught afterwards, there is + * the possibility of orphan worker processes. */ + signal(SIGTERM, SIG_DFL); + /* Free the command for starting new workers. */ if (state->config.start_worker_command != NULL) { int i = 0; @@ -471,7 +476,10 @@ void process_plasma_notification(event_loop *loop, /* Read the notification from Plasma. */ uint8_t *notification = read_message_async(loop, client_sock); if (!notification) { - return; + /* The store has closed the socket. 
*/ + LocalSchedulerState_free(state); + LOG_FATAL( + "Lost connection to the plasma store, local scheduler is exiting!"); } auto object_info = flatbuffers::GetRoot(notification); ObjectID object_id = from_flatbuf(object_info->object_id()); @@ -773,6 +781,10 @@ LocalSchedulerState *g_state; void signal_handler(int signal) { LOG_DEBUG("Signal was %d", signal); if (signal == SIGTERM) { + /* NOTE(swang): This call removes the SIGTERM handler to ensure that we + * free the local scheduler state at most once. If another SIGTERM is + * caught during this call, there is the possibility of orphan worker + * processes. */ LocalSchedulerState_free(g_state); exit(0); } @@ -842,7 +854,7 @@ int heartbeat_handler(event_loop *loop, timer_id id, void *context) { /* Publish the heartbeat to all subscribers of the local scheduler table. */ local_scheduler_table_send_info(state->db, &info, NULL); /* Reset the timer. */ - return LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS; + return HEARTBEAT_TIMEOUT_MILLISECONDS; } void start_server(const char *node_ip_address, @@ -887,7 +899,7 @@ void start_server(const char *node_ip_address, * scheduler to the local scheduler table. This message also serves as a * heartbeat. */ if (g_state->db != NULL) { - event_loop_add_timer(loop, LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS, + event_loop_add_timer(loop, HEARTBEAT_TIMEOUT_MILLISECONDS, heartbeat_handler, g_state); } /* Create a timer for fetching queued tasks' missing object dependencies. */ diff --git a/src/local_scheduler/local_scheduler.h b/src/local_scheduler/local_scheduler.h index 9918f3ea9..767388742 100644 --- a/src/local_scheduler/local_scheduler.h +++ b/src/local_scheduler/local_scheduler.h @@ -4,9 +4,6 @@ #include "task.h" #include "event_loop.h" -/* The duration between local scheduler heartbeats. */ -#define LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS 100 - /* The duration that we wait after sending a worker SIGTERM before sending the * worker SIGKILL. 
*/ #define KILL_WORKER_TIMEOUT_MILLISECONDS 100 diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index bc4bf554a..323f17731 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -9,18 +9,26 @@ LocalSchedulerConnection *LocalSchedulerConnection_init( const char *local_scheduler_socket, - ActorID actor_id) { + ActorID actor_id, + bool is_worker) { LocalSchedulerConnection *result = (LocalSchedulerConnection *) malloc(sizeof(LocalSchedulerConnection)); result->conn = connect_ipc_sock_retry(local_scheduler_socket, -1, -1); - flatbuffers::FlatBufferBuilder fbb; - auto message = - CreateRegisterWorkerInfo(fbb, to_flatbuf(fbb, actor_id), getpid()); - fbb.Finish(message); - /* Register the process ID with the local scheduler. */ - int success = write_message(result->conn, MessageType_RegisterWorkerInfo, - fbb.GetSize(), fbb.GetBufferPointer()); - CHECKM(success == 0, "Unable to register worker with local scheduler"); + + if (is_worker) { + /* If we are a worker, register with the local scheduler. + * NOTE(swang): If the local scheduler exits and we are registered as a + * worker, we will get killed. */ + flatbuffers::FlatBufferBuilder fbb; + auto message = + CreateRegisterWorkerInfo(fbb, to_flatbuf(fbb, actor_id), getpid()); + fbb.Finish(message); + /* Register the process ID with the local scheduler. */ + int success = write_message(result->conn, MessageType_RegisterWorkerInfo, + fbb.GetSize(), fbb.GetBufferPointer()); + CHECKM(success == 0, "Unable to register worker with local scheduler"); + } + return result; } diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index 4988bcf4a..014c1861e 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -17,11 +17,14 @@ typedef struct { * local scheduler. 
* @param actor_id The ID of the actor running on this worker. If no actor is * running on this actor, this should be NIL_ACTOR_ID. + * @param is_worker Whether this client is a worker. If it is a worker, an + * additional message will be sent to register as one. * @return The connection information. */ LocalSchedulerConnection *LocalSchedulerConnection_init( const char *local_scheduler_socket, - ActorID actor_id); + ActorID actor_id, + bool is_worker); /** * Disconnect from the local scheduler. diff --git a/src/local_scheduler/local_scheduler_extension.cc b/src/local_scheduler/local_scheduler_extension.cc index b0a08dde3..943d82643 100644 --- a/src/local_scheduler/local_scheduler_extension.cc +++ b/src/local_scheduler/local_scheduler_extension.cc @@ -18,13 +18,14 @@ static int PyLocalSchedulerClient_init(PyLocalSchedulerClient *self, PyObject *kwds) { char *socket_name; ActorID actor_id; - if (!PyArg_ParseTuple(args, "sO&", &socket_name, PyStringToUniqueID, - &actor_id)) { + PyObject *is_worker; + if (!PyArg_ParseTuple(args, "sO&O", &socket_name, PyStringToUniqueID, + &actor_id, &is_worker)) { return -1; } /* Connect to the local scheduler. 
*/ - self->local_scheduler_connection = - LocalSchedulerConnection_init(socket_name, actor_id); + self->local_scheduler_connection = LocalSchedulerConnection_init( + socket_name, actor_id, (bool) PyObject_IsTrue(is_worker)); return 0; } diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc index 93312f304..225e9740c 100644 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ b/src/local_scheduler/test/local_scheduler_tests.cc @@ -103,7 +103,7 @@ LocalSchedulerMock *LocalSchedulerMock_init(int num_workers, sizeof(LocalSchedulerConnection *) * num_mock_workers); for (int i = 0; i < num_mock_workers; ++i) { mock->conns[i] = LocalSchedulerConnection_init( - utstring_body(local_scheduler_socket_name), NIL_ACTOR_ID); + utstring_body(local_scheduler_socket_name), NIL_ACTOR_ID, true); new_client_connection(mock->loop, mock->local_scheduler_fd, (void *) mock->local_scheduler_state, 0); } diff --git a/src/plasma/plasma.cc b/src/plasma/plasma.cc index b964d9e35..83dce6ce1 100644 --- a/src/plasma/plasma.cc +++ b/src/plasma/plasma.cc @@ -7,16 +7,16 @@ #include "plasma_protocol.h" -void warn_if_sigpipe(int status, int client_sock) { +bool warn_if_sigpipe(int status, int client_sock) { if (status >= 0) { - return; + return false; } - if (errno == EPIPE || errno == EBADF) { + if (errno == EPIPE || errno == EBADF || errno == ECONNRESET) { LOG_WARN( "Received SIGPIPE or BAD FILE DESCRIPTOR when sending a message to " "client on fd %d. The client on the other end may have hung up.", client_sock); - return; + return true; } LOG_FATAL("Failed to write message to client on fd %d.", client_sock); } diff --git a/src/plasma/plasma.h b/src/plasma/plasma.h index 472c870ce..f482d4e26 100644 --- a/src/plasma/plasma.h +++ b/src/plasma/plasma.h @@ -132,7 +132,7 @@ typedef struct { * information. * @return Void. 
*/ -void warn_if_sigpipe(int status, int client_sock); +bool warn_if_sigpipe(int status, int client_sock); uint8_t *create_object_info_buffer(ObjectInfoT *object_info); diff --git a/src/plasma/plasma_manager.cc b/src/plasma/plasma_manager.cc index a16cfeafb..dcece41db 100644 --- a/src/plasma/plasma_manager.cc +++ b/src/plasma/plasma_manager.cc @@ -38,6 +38,7 @@ #include "state/object_table.h" #include "state/error_table.h" #include "state/task_table.h" +#include "state/db_client_table.h" /** * Process either the fetch or the status request. @@ -266,9 +267,6 @@ struct ClientConnection { int fd; /** Timer id for timing out wait (or fetch). */ int64_t timer_id; - /** The objects that we are waiting for and their callback - * contexts, for either a fetch or a wait operation. */ - ClientObjectRequest *active_objects; /** The number of objects that we have left to return for * this fetch or wait operation. */ int num_return_objects; @@ -280,6 +278,34 @@ struct ClientConnection { UT_hash_handle manager_hh; }; +/** + * Initializes the state for a plasma client connection. + * + * @param state The plasma manager state. + * @param client_sock The socket that we use to communicate with the client. + * @param client_key A string uniquely identifying the client. If the client is + * another plasma manager, this is the manager's IP address and port. + * Otherwise, this is the client's socket descriptor formatted as a string. + * @return A pointer to the initialized client state. + */ +ClientConnection *ClientConnection_init(PlasmaManagerState *state, + int client_sock, + char *client_key); + +/** + * Destroys a plasma client and its connection. + * + * @param client_conn The client's state. + * @return Void.
+ */ +void ClientConnection_free(ClientConnection *client_conn); + +void object_table_subscribe_callback(ObjectID object_id, + int64_t data_size, + int manager_count, + const char *manager_vector[], + void *context); + ObjectWaitRequests **object_wait_requests_table_ptr_from_type( PlasmaManagerState *manager_state, int type) { @@ -505,30 +531,10 @@ PlasmaManagerState *PlasmaManagerState_init(const char *store_socket_name, } void PlasmaManagerState_free(PlasmaManagerState *state) { - ClientConnection *manager_conn, *tmp; - HASH_ITER(manager_hh, state->manager_connections, manager_conn, tmp) { - HASH_DELETE(manager_hh, state->manager_connections, manager_conn); - - /* Free the hash table of object IDs that are waiting to be transferred. */ - PlasmaRequestBuffer *request_buffer, *tmp_buffer; - HASH_ITER(hh, manager_conn->pending_object_transfers, request_buffer, - tmp_buffer) { - /* We do not free the PlasmaRequestBuffer here because it is also in the - * transfer queue and will be freed below. */ - HASH_DELETE(hh, manager_conn->pending_object_transfers, request_buffer); - } - - /* Free the transfer queue. */ - PlasmaRequestBuffer *head = manager_conn->transfer_queue; - while (head) { - DL_DELETE(manager_conn->transfer_queue, head); - free(head); - head = manager_conn->transfer_queue; - } - /* Close the manager connection and free the remaining state. 
*/ - close(manager_conn->fd); - free(manager_conn->ip_addr_port); - free(manager_conn); + ClientConnection *manager_conn, *tmp_manager_conn; + HASH_ITER(manager_hh, state->manager_connections, manager_conn, + tmp_manager_conn) { + ClientConnection_free(manager_conn); } if (state->fetch_requests != NULL) { @@ -538,6 +544,12 @@ void PlasmaManagerState_free(PlasmaManagerState *state) { } } + AvailableObject *entry, *tmp_object_entry; + HASH_ITER(hh, state->local_available_objects, entry, tmp_object_entry) { + HASH_DELETE(hh, state->local_available_objects, entry); + free(entry); + } + plasma_disconnect(state->plasma_conn); event_loop_destroy(state->loop); free_protocol_builder(state->builder); @@ -601,9 +613,10 @@ void send_queued_request(event_loop *loop, } PlasmaRequestBuffer *buf = conn->transfer_queue; + bool sigpipe = false; switch (buf->type) { case MessageType_PlasmaDataRequest: - warn_if_sigpipe( + sigpipe = warn_if_sigpipe( plasma_send_DataRequest(conn->fd, state->builder, buf->object_id, state->addr, state->port), conn->fd); @@ -613,7 +626,7 @@ void send_queued_request(event_loop *loop, if (conn->cursor == 0) { /* If the cursor is zero, we haven't sent any requests for this object * yet, so send the initial data request. */ - warn_if_sigpipe( + sigpipe = warn_if_sigpipe( plasma_send_DataReply(conn->fd, state->builder, buf->object_id, buf->data_size, buf->metadata_size), conn->fd); @@ -624,6 +637,12 @@ void send_queued_request(event_loop *loop, LOG_FATAL("Buffered request has unknown type."); } + /* If there was a SIGPIPE, stop sending to this manager. */ + if (sigpipe) { + ClientConnection_free(conn); + return; + } + /* If we are done sending this request, remove it from the transfer queue. 
*/ if (conn->cursor == 0) { if (buf->type == MessageType_PlasmaDataReply) { @@ -728,21 +747,13 @@ ClientConnection *get_manager_connection(PlasmaManagerState *state, utstring_len(ip_addr_port), manager_conn); if (!manager_conn) { /* If we don't already have a connection to this manager, start one. */ - int fd = connect_inet_sock_retry(ip_addr, port, -1, -1); - /* TODO(swang): Handle the case when connection to this manager was - * unsuccessful. */ - CHECK(fd >= 0); - manager_conn = (ClientConnection *) malloc(sizeof(ClientConnection)); - manager_conn->fd = fd; - manager_conn->manager_state = state; - manager_conn->transfer_queue = NULL; - manager_conn->pending_object_transfers = NULL; - manager_conn->cursor = 0; - manager_conn->ip_addr_port = strdup(utstring_body(ip_addr_port)); - HASH_ADD_KEYPTR(manager_hh, - manager_conn->manager_state->manager_connections, - manager_conn->ip_addr_port, - strlen(manager_conn->ip_addr_port), manager_conn); + int fd = connect_inet_sock(ip_addr, port); + if (fd < 0) { + return NULL; + } + + manager_conn = + ClientConnection_init(state, fd, utstring_body(ip_addr_port)); } utstring_free(ip_addr_port); return manager_conn; @@ -755,6 +766,9 @@ void process_transfer_request(event_loop *loop, ClientConnection *conn) { ClientConnection *manager_conn = get_manager_connection(conn->manager_state, addr, port); + if (manager_conn == NULL) { + return; + } /* If there is already a request in the transfer queue with the same object * ID, do not add the transfer request. */ @@ -765,6 +779,20 @@ void process_transfer_request(event_loop *loop, return; } + /* Allocate and append the request to the transfer queue. */ + ObjectBuffer object_buffer; + /* We pass in 0 to indicate that the command should return immediately. */ + plasma_get(conn->manager_state->plasma_conn, &obj_id, 1, 0, &object_buffer); + if (object_buffer.data_size == -1) { + /* If the object wasn't locally available, exit immediately. 
If the object + * later appears locally, the requesting plasma manager should request the + * transfer again. */ + LOG_WARN( + "Unable to transfer object to requesting plasma manager, object not " + "local."); + return; + } + /* If we already have a connection to this manager and its inactive, * (re)register it with the event loop again. */ if (manager_conn->transfer_queue == NULL) { @@ -772,38 +800,17 @@ void process_transfer_request(event_loop *loop, send_queued_request, manager_conn); } - /* Allocate and append the request to the transfer queue. */ - /* TODO(swang): A non-blocking plasma_get, or else we could block here - * forever if we don't end up sealing this object. */ - /* The corresponding call to plasma_release will happen in - * write_object_chunk. */ - /* TODO(rkn): The manager currently will block here if the object is not - * present in the store. This is completely unacceptable. The manager should - * do a non-blocking get call on the store, and if the object isn't there then - * perhaps the manager should initiate the transfer when it receives a - * notification from the store that the object is present. */ - ObjectBuffer obj_buffer; - int counter = 0; - do { - /* We pass in 0 to indicate that the command should return immediately. */ - ObjectID obj_id_array[1] = {obj_id}; - plasma_get(conn->manager_state->plasma_conn, obj_id_array, 1, 0, - &obj_buffer); - if (counter > 0) { - LOG_WARN("Blocking in the plasma manager."); - } - counter += 1; - } while (obj_buffer.data_size == -1); - DCHECK(obj_buffer.metadata == obj_buffer.data + obj_buffer.data_size); + DCHECK(object_buffer.metadata == + object_buffer.data + object_buffer.data_size); PlasmaRequestBuffer *buf = (PlasmaRequestBuffer *) malloc(sizeof(PlasmaRequestBuffer)); buf->type = MessageType_PlasmaDataReply; buf->object_id = obj_id; /* We treat buf->data as a pointer to the concatenated data and metadata, so * we don't actually use buf->metadata. 
*/ - buf->data = obj_buffer.data; - buf->data_size = obj_buffer.data_size; - buf->metadata_size = obj_buffer.metadata_size; + buf->data = object_buffer.data; + buf->data_size = object_buffer.data_size; + buf->metadata_size = object_buffer.metadata_size; DL_APPEND(manager_conn->transfer_queue, buf); HASH_ADD(hh, manager_conn->pending_object_transfers, object_id, @@ -868,14 +875,7 @@ void process_data_request(event_loop *loop, } void request_transfer_from(PlasmaManagerState *manager_state, - ObjectID object_id) { - FetchRequest *fetch_req; - HASH_FIND(hh, manager_state->fetch_requests, &object_id, sizeof(object_id), - fetch_req); - /* TODO(rkn): This probably can be NULL so we should remove this check, and - * instead return in the case where there is no fetch request. */ - CHECK(fetch_req != NULL); - + FetchRequest *fetch_req) { CHECK(fetch_req->manager_count > 0); CHECK(fetch_req->next_manager >= 0 && fetch_req->next_manager < fetch_req->manager_count); @@ -886,30 +886,33 @@ void request_transfer_from(PlasmaManagerState *manager_state, ClientConnection *manager_conn = get_manager_connection(manager_state, addr, port); + if (manager_conn != NULL) { + /* Check that this manager isn't trying to request an object from itself. + * TODO(rkn): Later this should not be fatal. */ + uint8_t temp_addr[4]; + sscanf(addr, "%hhu.%hhu.%hhu.%hhu", &temp_addr[0], &temp_addr[1], + &temp_addr[2], &temp_addr[3]); + if (memcmp(temp_addr, manager_state->addr, 4) == 0 && + port == manager_state->port) { + LOG_FATAL( + "This manager is attempting to request a transfer from itself."); + } - /* Check that this manager isn't trying to request an object from itself. - * TODO(rkn): Later this should not be fatal. 
*/ - uint8_t temp_addr[4]; - sscanf(addr, "%hhu.%hhu.%hhu.%hhu", &temp_addr[0], &temp_addr[1], - &temp_addr[2], &temp_addr[3]); - if (memcmp(temp_addr, manager_state->addr, 4) == 0 && - port == manager_state->port) { - LOG_FATAL("This manager is attempting to request a transfer from itself."); + PlasmaRequestBuffer *transfer_request = + (PlasmaRequestBuffer *) malloc(sizeof(PlasmaRequestBuffer)); + transfer_request->type = MessageType_PlasmaDataRequest; + transfer_request->object_id = fetch_req->object_id; + + if (manager_conn->transfer_queue == NULL) { + /* If we already have a connection to this manager and it's inactive, + * (re)register it with the event loop. */ + event_loop_add_file(manager_state->loop, manager_conn->fd, + EVENT_LOOP_WRITE, send_queued_request, manager_conn); + } + /* Add this transfer request to this connection's transfer queue. */ + DL_APPEND(manager_conn->transfer_queue, transfer_request); } - PlasmaRequestBuffer *transfer_request = - (PlasmaRequestBuffer *) malloc(sizeof(PlasmaRequestBuffer)); - transfer_request->type = MessageType_PlasmaDataRequest; - transfer_request->object_id = fetch_req->object_id; - - if (manager_conn->transfer_queue == NULL) { - /* If we already have a connection to this manager and its inactive, - * (re)register it with the event loop. */ - event_loop_add_file(manager_state->loop, manager_conn->fd, EVENT_LOOP_WRITE, - send_queued_request, manager_conn); - } - /* Add this transfer request to this connection's transfer queue. */ - DL_APPEND(manager_conn->transfer_queue, transfer_request); /* On the next attempt, try the next manager in manager_vector. */ fetch_req->next_manager += 1; fetch_req->next_manager %= fetch_req->manager_count; @@ -917,13 +920,39 @@ void request_transfer_from(PlasmaManagerState *manager_state, int fetch_timeout_handler(event_loop *loop, timer_id id, void *context) { PlasmaManagerState *manager_state = (PlasmaManagerState *) context; - /* Loop over the fetch requests and reissue the requests. 
*/ + + /* Allocate a vector of object IDs to resend requests for location + * notifications. */ + int num_object_ids_to_request = 0; + int num_object_ids = HASH_COUNT(manager_state->fetch_requests); + /* This is allocating more space than necessary, but we do not know the exact + * number of object IDs to request notifications for yet. */ + ObjectID *object_ids_to_request = + (ObjectID *) malloc(num_object_ids * sizeof(ObjectID)); + + /* Loop over the fetch requests and reissue requests for objects whose + * locations we know. */ FetchRequest *fetch_req, *tmp; HASH_ITER(hh, manager_state->fetch_requests, fetch_req, tmp) { if (fetch_req->manager_count > 0) { - request_transfer_from(manager_state, fetch_req->object_id); + request_transfer_from(manager_state, fetch_req); + /* If we've tried all of the managers that we know about for this object, + * add this object to the list to resend requests for. */ + if (fetch_req->next_manager == 0) { + object_ids_to_request[num_object_ids_to_request] = fetch_req->object_id; + ++num_object_ids_to_request; + } } } + + /* Resend requests for notifications on these objects' locations. */ + if (num_object_ids_to_request > 0 && manager_state->db != NULL) { + object_table_request_notifications(manager_state->db, + num_object_ids_to_request, + object_ids_to_request, NULL); + } + free(object_ids_to_request); + return MANAGER_TIMEOUT; } @@ -980,7 +1009,7 @@ void request_transfer(ObjectID object_id, } /* Wait for the object data for the default number of retries, which timeout * after a default interval. */ - request_transfer_from(manager_state, object_id); + request_transfer_from(manager_state, fetch_req); } /* This method is only called from the tests. 
*/ @@ -1351,7 +1380,9 @@ void process_object_notification(event_loop *loop, PlasmaManagerState *state = (PlasmaManagerState *) context; uint8_t *notification = read_message_async(loop, client_sock); if (notification == NULL) { - return; + PlasmaManagerState_free(state); + LOG_FATAL( + "Lost connection to the plasma store, plasma manager is exiting!"); } auto object_info = flatbuffers::GetRoot(notification); /* Add object to locally available object. */ @@ -1367,6 +1398,79 @@ void process_object_notification(event_loop *loop, free(notification); } +/* TODO(pcm): Split this into two methods: new_worker_connection + * and new_manager_connection and also split ClientConnection + * into two structs, one for workers and one for other plasma managers. */ +ClientConnection *ClientConnection_init(PlasmaManagerState *state, + int client_sock, + char *client_key) { + /* Create a new data connection context per client. */ + ClientConnection *conn = + (ClientConnection *) malloc(sizeof(ClientConnection)); + conn->manager_state = state; + conn->cursor = 0; + conn->transfer_queue = NULL; + conn->pending_object_transfers = NULL; + conn->fd = client_sock; + conn->num_return_objects = 0; + + conn->ip_addr_port = strdup(client_key); + HASH_ADD_KEYPTR(manager_hh, conn->manager_state->manager_connections, + conn->ip_addr_port, strlen(conn->ip_addr_port), conn); + return conn; +} + +ClientConnection *ClientConnection_listen(event_loop *loop, + int listener_sock, + void *context, + int events) { + PlasmaManagerState *state = (PlasmaManagerState *) context; + int new_socket = accept_client(listener_sock); + char client_key[8]; + snprintf(client_key, sizeof(client_key), "%d", new_socket); + ClientConnection *conn = ClientConnection_init(state, new_socket, client_key); + + event_loop_add_file(loop, new_socket, EVENT_LOOP_READ, process_message, conn); + LOG_DEBUG("New client connection with fd %d", new_socket); + return conn; +} + +void ClientConnection_free(ClientConnection *client_conn) { 
+ PlasmaManagerState *state = client_conn->manager_state; + HASH_DELETE(manager_hh, state->manager_connections, client_conn); + /* Free the hash table of object IDs that are waiting to be transferred. */ + PlasmaRequestBuffer *request_buffer, *tmp_buffer; + HASH_ITER(hh, client_conn->pending_object_transfers, request_buffer, + tmp_buffer) { + /* We do not free the PlasmaRequestBuffer here because it is also in the + * transfer queue and will be freed below. */ + HASH_DELETE(hh, client_conn->pending_object_transfers, request_buffer); + } + + /* Free the transfer queue. */ + PlasmaRequestBuffer *head = client_conn->transfer_queue; + while (head) { + DL_DELETE(client_conn->transfer_queue, head); + free(head); + head = client_conn->transfer_queue; + } + /* Close the manager connection and free the remaining state. */ + close(client_conn->fd); + free(client_conn->ip_addr_port); + free(client_conn); +} + +void handle_new_client(event_loop *loop, + int listener_sock, + void *context, + int events) { + (void) ClientConnection_listen(loop, listener_sock, context, events); +} + +int get_client_sock(ClientConnection *conn) { + return conn->fd; +} + void process_message(event_loop *loop, int client_sock, void *context, @@ -1433,11 +1537,8 @@ void process_message(event_loop *loop, } break; case DISCONNECT_CLIENT: { LOG_INFO("Disconnecting client on fd %d", client_sock); - /* TODO(swang): Check if this connection was to a plasma manager. If so, - * delete it. */ event_loop_remove_file(loop, client_sock); - close(client_sock); - free(conn); + ClientConnection_free(conn); } break; default: LOG_FATAL("invalid request %" PRId64, type); @@ -1445,39 +1546,10 @@ void process_message(event_loop *loop, free(data); } -/* TODO(pcm): Split this into two methods: new_worker_connection - * and new_manager_connection and also split ClientConnection - * into two structs, one for workers and one for other plasma managers. 
*/ -ClientConnection *ClientConnection_init(event_loop *loop, - int listener_sock, - void *context, - int events) { - int new_socket = accept_client(listener_sock); - /* Create a new data connection context per client. */ - ClientConnection *conn = - (ClientConnection *) malloc(sizeof(ClientConnection)); - conn->manager_state = (PlasmaManagerState *) context; - conn->cursor = 0; - conn->transfer_queue = NULL; - /* TODO(rkn): Is this pending_object_transfers hash table ever used? */ - conn->pending_object_transfers = NULL; - conn->fd = new_socket; - conn->active_objects = NULL; - conn->num_return_objects = 0; - event_loop_add_file(loop, new_socket, EVENT_LOOP_READ, process_message, conn); - LOG_DEBUG("New client connection with fd %d", new_socket); - return conn; -} - -void handle_new_client(event_loop *loop, - int listener_sock, - void *context, - int events) { - (void) ClientConnection_init(loop, listener_sock, context, events); -} - -int get_client_sock(ClientConnection *conn) { - return conn->fd; +int heartbeat_handler(event_loop *loop, timer_id id, void *context) { + PlasmaManagerState *state = (PlasmaManagerState *) context; + plasma_manager_send_heartbeat(state->db); + return HEARTBEAT_TIMEOUT_MILLISECONDS; } void start_server(const char *store_socket_name, @@ -1523,6 +1595,9 @@ void start_server(const char *store_socket_name, * requests and reissue requests for transfers of those objects. */ event_loop_add_timer(g_manager_state->loop, MANAGER_TIMEOUT, fetch_timeout_handler, g_manager_state); + /* Publish the heartbeats to all subscribers of the plasma manager table. */ + event_loop_add_timer(g_manager_state->loop, HEARTBEAT_TIMEOUT_MILLISECONDS, + heartbeat_handler, g_manager_state); /* Run the event loop. 
*/ event_loop_run(g_manager_state->loop); } diff --git a/src/plasma/plasma_manager.h b/src/plasma/plasma_manager.h index ccf68bc17..b65880f8f 100644 --- a/src/plasma/plasma_manager.h +++ b/src/plasma/plasma_manager.h @@ -154,10 +154,10 @@ void send_queued_request(event_loop *loop, * @param context The plasma manager state. * @return Void. */ -ClientConnection *ClientConnection_init(event_loop *loop, - int listener_sock, - void *context, - int events); +ClientConnection *ClientConnection_listen(event_loop *loop, + int listener_sock, + void *context, + int events); /** * The following definitions are internal to the plasma manager code but are diff --git a/src/plasma/test/manager_tests.cc b/src/plasma/test/manager_tests.cc index c962957c2..3514b8e17 100644 --- a/src/plasma/test/manager_tests.cc +++ b/src/plasma/test/manager_tests.cc @@ -77,8 +77,8 @@ plasma_mock *init_plasma_mock(plasma_mock *remote_mock) { get_manager_connection(remote_mock->state, manager_addr, mock->port); wait_for_pollin(mock->manager_remote_fd); mock->read_conn = - ClientConnection_init(mock->loop, mock->manager_remote_fd, mock->state, - PLASMA_DEFAULT_RELEASE_DELAY); + ClientConnection_listen(mock->loop, mock->manager_remote_fd, + mock->state, PLASMA_DEFAULT_RELEASE_DELAY); } else { mock->write_conn = NULL; mock->read_conn = NULL; @@ -88,19 +88,14 @@ plasma_mock *init_plasma_mock(plasma_mock *remote_mock) { mock->plasma_conn = plasma_connect(plasma_store_socket_name, utstring_body(manager_socket_name), 0); wait_for_pollin(mock->manager_local_fd); - mock->client_conn = - ClientConnection_init(mock->loop, mock->manager_local_fd, mock->state, 0); + mock->client_conn = ClientConnection_listen( + mock->loop, mock->manager_local_fd, mock->state, 0); utstring_free(manager_socket_name); return mock; } void destroy_plasma_mock(plasma_mock *mock) { - if (mock->read_conn != NULL) { - close(get_client_sock(mock->read_conn)); - free(mock->read_conn); - } PlasmaManagerState_free(mock->state); - 
free(mock->client_conn); plasma_disconnect(mock->plasma_conn); close(mock->local_store); close(mock->manager_local_fd); diff --git a/test/component_failures_test.py b/test/component_failures_test.py index 12da5e293..49fc7d8ef 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -8,6 +8,10 @@ import time import unittest class ComponentFailureTest(unittest.TestCase): + + def tearDown(self): + ray.worker.cleanup() + # This test checks that when a worker dies in the middle of a get, the plasma # store and manager will not die. def testDyingWorkerGet(self): @@ -37,7 +41,6 @@ class ComponentFailureTest(unittest.TestCase): # Make sure that nothing has died. self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER])) - ray.worker.cleanup() # This test checks that when a worker dies in the middle of a wait, the plasma # store and manager will not die. @@ -68,7 +71,6 @@ class ComponentFailureTest(unittest.TestCase): # Make sure that nothing has died. self.assertTrue(ray.services.all_processes_alive(exclude=[ray.services.PROCESS_TYPE_WORKER])) - ray.worker.cleanup() def _testWorkerFailed(self, num_local_schedulers): @ray.remote @@ -95,15 +97,15 @@ class ComponentFailureTest(unittest.TestCase): # Make sure that we can still get the objects after the executing tasks died. ray.get(object_ids) - ray.worker.cleanup() - def testWorkerFailed(self): self._testWorkerFailed(1) def testWorkerFailedMultinode(self): self._testWorkerFailed(4) - def testNodeFailed(self): + def _testComponentFailed(self, component_type): + """Kill a component on all worker nodes and check that workload succeeds. + """ @ray.remote def f(x, j): time.sleep(0.2) @@ -123,11 +125,16 @@ class ComponentFailureTest(unittest.TestCase): object_ids += [f.remote(object_id, 1) for object_id in object_ids] object_ids += [f.remote(object_id, 2) for object_id in object_ids] - # Kill all nodes except the head node as the tasks execute. 
+ # Kill the component on all nodes except the head node as the tasks + # execute. time.sleep(0.1) - local_schedulers = ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER] - for process in local_schedulers[1:]: + components = ray.services.all_processes[component_type] + for process in components[1:]: process.terminate() + time.sleep(0.1) + process.kill() + process.wait() + self.assertNotEqual(process.poll(), None) time.sleep(1) # Make sure that we can still get the objects after the executing tasks @@ -136,8 +143,84 @@ class ComponentFailureTest(unittest.TestCase): expected_results = 4 * list(range(num_workers_per_scheduler * num_local_schedulers)) self.assertEqual(results, expected_results) - ray.worker.cleanup() + def check_components_alive(self, component_type, check_component_alive): + """Check that a component type is alive or dead on all worker nodes. + """ + components = ray.services.all_processes[component_type][1:] + for component in components: + if check_component_alive: + self.assertTrue(component.poll() is None) + else: + self.assertTrue(component.poll() is not None) + def testLocalSchedulerFailed(self): + # Kill all local schedulers on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER) + + # The plasma stores and plasma managers should still be alive on the worker + # nodes. + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, True) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + + def testPlasmaManagerFailed(self): + # Kill all plasma managers on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_MANAGER) + + # The plasma stores should still be alive (but unreachable) on the worker + # nodes.
+ self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, True) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + + def testPlasmaStoreFailed(self): + # Kill all plasma stores on worker nodes. + self._testComponentFailed(ray.services.PROCESS_TYPE_PLASMA_STORE) + + # No processes should be left alive on the worker nodes. + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_STORE, False) + self.check_components_alive(ray.services.PROCESS_TYPE_PLASMA_MANAGER, False) + self.check_components_alive(ray.services.PROCESS_TYPE_LOCAL_SCHEDULER, False) + + def testDriverLivesSequential(self): + ray.worker.init() + processes = [ + ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0], + ] + + # Kill all the components sequentially. + for process in processes: + process.terminate() + time.sleep(0.1) + process.kill() + process.wait() + + # If the driver can reach the tearDown method, then it is still alive. + + def testDriverLivesParallel(self): + ray.worker.init() + processes = [ + ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_STORE][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_PLASMA_MANAGER][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER][0], + ray.services.all_processes[ray.services.PROCESS_TYPE_GLOBAL_SCHEDULER][0], + ] + + # Kill all the components in parallel. + for process in processes: + process.terminate() + + time.sleep(0.1) + for process in processes: + process.kill() + + for process in processes: + process.wait() + + # If the driver can reach the tearDown method, then it is still alive. if __name__ == "__main__": unittest.main(verbosity=2)