diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 1059e13fc..147fe39d8 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -6,6 +6,8 @@ def get_local_schedulers(worker): local_schedulers = [] for client in worker.redis_client.keys("CL:*"): client_info = worker.redis_client.hgetall(client) + if b"client_type" not in client_info: + continue if client_info[b"client_type"] == b"local_scheduler": local_schedulers.append(client_info) return local_schedulers diff --git a/python/ray/monitor.py b/python/ray/monitor.py new file mode 100644 index 000000000..b52d612a0 --- /dev/null +++ b/python/ray/monitor.py @@ -0,0 +1,187 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import binascii +from collections import Counter +import logging +import redis +import time + +from ray.services import get_ip_address +from ray.services import get_port + +# These variables must be kept in sync with the C codebase. +# common/common.h +DB_CLIENT_ID_SIZE = 20 +NIL_ID = b"\xff" * DB_CLIENT_ID_SIZE +# common/task.h +TASK_STATUS_LOST = 32 +# common/redis_module/ray_redis_module.c +TASK_PREFIX = "TT:" +DB_CLIENT_PREFIX = "CL:" +DB_CLIENT_TABLE_NAME = b"db_clients" +# local_scheduler/local_scheduler.h +LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS = 100 +LOCAL_SCHEDULER_CLIENT_TYPE = b"local_scheduler" + +# Set up logging. +logging.basicConfig() +log = logging.getLogger() + +class Monitor(object): + """A monitor for Ray processes. + + The monitor is in charge of cleaning up the tables in the global state after + processes have died. The monitor is currently not responsible for detecting + component failures. + + Attributes: + redis: A connection to the Redis server. + subscribe_client: A pubsub client for the Redis server. This is used to + receive notifications about failed components. 
+    local_schedulers: A set of the local scheduler IDs of all of the currently
+      live local schedulers in the cluster. In addition, this also includes
+      NIL_ID.
+  """
+  def __init__(self, redis_address, redis_port):
+    self.redis = redis.StrictRedis(host=redis_address, port=redis_port, db=0)
+    self.subscribe_client = self.redis.pubsub()
+
+    # Initialize data structures to keep track of the active database clients.
+    self.local_schedulers = set()
+    # Add the NIL_ID so that we don't accidentally mark tasks that aren't
+    # associated with a node as LOST during cleanup.
+    self.local_schedulers.add(NIL_ID)
+
+  def subscribe(self):
+    """Subscribe to the db_clients channel.
+
+    Raises:
+      Exception: An exception is raised if the subscription fails.
+    """
+    self.subscribe_client.subscribe(DB_CLIENT_TABLE_NAME)
+    # Wait for the first message to signal that the subscription was successful.
+    while True:
+      message = self.subscribe_client.get_message()
+      if message is None:
+        time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000)
+        continue
+      break
+
+    # The first message's payload should be the index of our subscription.
+    if "data" not in message:
+      raise Exception("Unable to subscribe to local scheduler table.")
+
+  def read_message(self):
+    """Read a message from the db_clients channel.
+
+    Returns:
+      None if there was no message to read. Otherwise, a tuple of (db_client_id,
+        client_type, auxiliary_address, is_insertion) is returned. The value
+        is_insertion is a bool that is true if the update to the db_clients
+        table was an insertion and false if deletion.
+    """
+    message = self.subscribe_client.get_message()
+    if message is None:
+      return None
+
+    # Parse the message.
+ data = message["data"] + db_client_id = data[:DB_CLIENT_ID_SIZE] + data = data[DB_CLIENT_ID_SIZE + 1:] + data = data.split(b" ") + client_type, auxiliary_address, is_insertion = data + is_insertion = int(is_insertion) + if is_insertion != 1 and is_insertion != 0: + raise Exception("Expected 0 or 1 for insertion field, got {} instead".format(is_insertion)) + is_insertion = bool(is_insertion) + + return db_client_id, client_type, auxiliary_address, is_insertion + + def cleanup_task_table(self): + """Clean up global state for a failed local schedulers. + + This marks any tasks that were scheduled on dead local schedulers as + TASK_STATUS_LOST. A local scheduler is deemed dead if it is not in + self.local_schedulers. + """ + task_ids = self.redis.scan_iter(match="{prefix}*".format(prefix=TASK_PREFIX)) + for task_id in task_ids: + task_id = task_id[len(TASK_PREFIX):] + response = self.redis.execute_command("RAY.TASK_TABLE_GET", task_id) + if response[1] not in self.local_schedulers: + ok = self.redis.execute_command("RAY.TASK_TABLE_UPDATE", + task_id, + TASK_STATUS_LOST, + NIL_ID) + if ok != b"OK": + log.warn("Failed to update lost task for dead scheduler.") + + def scan_db_client_table(self): + """Scan the database client table for the current clients. + + After subscribing to the client table, it's necessary to call this before + reading any messages from the subscription channel. + """ + db_client_keys = self.redis.keys("{prefix}*".format(prefix=DB_CLIENT_PREFIX)) + for db_client_key in db_client_keys: + db_client_id = db_client_key[len(DB_CLIENT_PREFIX):] + client_type = self.redis.hget(db_client_key, "client_type") + if client_type == LOCAL_SCHEDULER_CLIENT_TYPE: + self.local_schedulers.add(db_client_id) + + def run(self): + """Run the monitor. + + This function loops forever, checking for messages about dead database + clients and cleaning up state accordingly. + """ + # Initialize the subscription channel. 
+    self.subscribe()
+
+    # Scan the database table and clean up any state associated with clients
+    # not in the database table. NOTE: This must be called before reading any
+    # messages from the subscription channel. This ensures that we start in a
+    # consistent state, since we may have missed notifications that were sent
+    # before we connected to the subscription channel.
+    self.scan_db_client_table()
+    self.cleanup_task_table()
+    log.debug("Scanned schedulers: {}".format(self.local_schedulers))
+
+    # Read messages from the subscription channel.
+    while True:
+      time.sleep(LOCAL_SCHEDULER_HEARTBEAT_TIMEOUT_MILLISECONDS / 1000)
+      client = self.read_message()
+      # There was no message to be read.
+      if client is None:
+        continue
+
+      db_client_id, client_type, auxiliary_address, is_insertion = client
+
+      # Only local schedulers are tracked; record the ID of an inserted one.
+      if is_insertion and client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
+        self.local_schedulers.add(db_client_id)
+        log.debug("Added scheduler: {}".format(db_client_id))
+        continue
+
+      # Deletions of local schedulers need cleanup; other updates are ignored.
+      if client_type == LOCAL_SCHEDULER_CLIENT_TYPE:
+        if db_client_id in self.local_schedulers:
+          log.warning("Removed scheduler: {}".format(db_client_id))
+          self.local_schedulers.remove(db_client_id)
+          self.cleanup_task_table()
+
+if __name__ == "__main__":
+  parser = argparse.ArgumentParser(description=("Parse Redis server for the "
+                                                "monitor to connect to."))
+  parser.add_argument("--redis-address", required=True, type=str,
+                      help="the address to use for Redis")
+  args = parser.parse_args()
+
+  redis_ip_address = get_ip_address(args.redis_address)
+  redis_port = get_port(args.redis_address)
+
+  monitor = Monitor(redis_ip_address, redis_port)
+  monitor.run()
diff --git a/python/ray/services.py b/python/ray/services.py
index 7e91f002d..b50049553 100644
--- a/python/ray/services.py
+++ b/python/ray/services.py
@@ -21,6 +21,7 @@ import ray.local_scheduler as local_scheduler
 import ray.plasma as plasma
 import ray.global_scheduler as global_scheduler
 
+PROCESS_TYPE_MONITOR = "monitor"
 PROCESS_TYPE_WORKER = "worker"
 PROCESS_TYPE_LOCAL_SCHEDULER = "local_scheduler"
 PROCESS_TYPE_PLASMA_MANAGER = "plasma_manager"
@@ -34,13 +35,14 @@ PROCESS_TYPE_WEB_UI = "web_ui"
 # important because it determines the order in which these processes will be
 # terminated when Ray exits, and certain orders will cause errors to be logged
 # to the screen.
-all_processes = OrderedDict([(PROCESS_TYPE_WORKER, []),
+all_processes = OrderedDict([(PROCESS_TYPE_MONITOR, []),
+                             (PROCESS_TYPE_WORKER, []),
                              (PROCESS_TYPE_LOCAL_SCHEDULER, []),
                              (PROCESS_TYPE_PLASMA_MANAGER, []),
                              (PROCESS_TYPE_PLASMA_STORE, []),
                              (PROCESS_TYPE_GLOBAL_SCHEDULER, []),
                              (PROCESS_TYPE_REDIS_SERVER, []),
-                             (PROCESS_TYPE_WEB_UI, [])])
+                             (PROCESS_TYPE_WEB_UI, [])],)
 
 # True if processes are run in the valgrind profiler.
 RUN_LOCAL_SCHEDULER_PROFILER = False
@@ -527,7 +529,7 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
     object_store_name (str): The name of the object store.
     object_store_manager_name (str): The name of the object store manager.
     local_scheduler_name (str): The name of the local scheduler.
-    redis_address (int): The address that the Redis server is listening on.
+    redis_address (str): The address that the Redis server is listening on.
     worker_path (str): The path of the source code which the worker process
       will run.
     stdout_file: A file handle opened for writing to redirect stdout to. If no
@@ -549,6 +551,28 @@ def start_worker(node_ip_address, object_store_name, object_store_manager_name,
   if cleanup:
     all_processes[PROCESS_TYPE_WORKER].append(p)
 
+def start_monitor(redis_address, stdout_file=None, stderr_file=None,
+                  cleanup=True):
+  """Run a process to monitor the other processes.
+
+  Args:
+    redis_address (str): The address that the Redis server is listening on.
+    stdout_file: A file handle opened for writing to redirect stdout to. If no
+      redirection should happen, then this should be None.
+    stderr_file: A file handle opened for writing to redirect stderr to. If no
+      redirection should happen, then this should be None.
+    cleanup (bool): True if using Ray in local mode. If cleanup is true, then
+      this process will be killed by services.cleanup() when the Python process
+      that imported services exits. This is True by default.
+  """
+  monitor_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
+                              "monitor.py")
+  command = ["python", monitor_path,
+             "--redis-address=" + str(redis_address)]
+  p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file)
+  if cleanup:
+    all_processes[PROCESS_TYPE_MONITOR].append(p)
+
 def start_ray_processes(address_info=None,
                         node_ip_address="127.0.0.1",
                         num_workers=0,
@@ -641,6 +665,11 @@ def start_ray_processes(address_info=None,
                                  stderr_file=redis_stderr_file,
                                  cleanup=cleanup)
     assert redis_port == new_redis_port
+    # Start monitoring the processes.
+    monitor_stdout_file, monitor_stderr_file = new_log_files("monitor", redirect_output)
+    start_monitor(redis_address,
+                  stdout_file=monitor_stdout_file,
+                  stderr_file=monitor_stderr_file, cleanup=cleanup)
   else:
     if redis_address is None:
       raise Exception("Redis address expected")
diff --git a/python/ray/worker.py b/python/ray/worker.py
index 23fc32ba6..9ce571ec7 100644
--- a/python/ray/worker.py
+++ b/python/ray/worker.py
@@ -958,9 +958,13 @@ def cleanup(worker=global_worker):
                                {"end_time": time.time()})
     services.cleanup()
   else:
-    # If this is not a driver, make sure there are no orphan processes.
+    # If this is not a driver, make sure there are no orphan processes, besides
+    # possibly the worker itself.
     for process_type, processes in services.all_processes.items():
-      assert(len(processes) == 0)
+      if process_type == services.PROCESS_TYPE_WORKER:
+        assert len(processes) <= 1
+      else:
+        assert len(processes) == 0
 
   worker.set_mode(None)
 
diff --git a/scripts/stop_ray.sh b/scripts/stop_ray.sh
index 59765829c..d983fe0d7 100755
--- a/scripts/stop_ray.sh
+++ b/scripts/stop_ray.sh
@@ -2,6 +2,9 @@
 
 killall global_scheduler plasma_store plasma_manager local_scheduler
 
+# Find the PID of the monitor process and kill it.
+kill $(ps aux | grep monitor.py | awk '{ print $2 }') 2> /dev/null
+
 # Find the PID of the Redis process and kill it.
 kill $(ps aux | grep redis-server | awk '{ print $2 }') 2> /dev/null
 
diff --git a/src/common/redis_module/ray_redis_module.c b/src/common/redis_module/ray_redis_module.c
index 8e5d094f6..8a940d9f1 100644
--- a/src/common/redis_module/ray_redis_module.c
+++ b/src/common/redis_module/ray_redis_module.c
@@ -47,6 +47,47 @@ RedisModuleKey *OpenPrefixedKey(RedisModuleCtx *ctx,
   return key;
 }
 
+/**
+ * Publish a notification to a client's notification channel about an insertion
+ * or deletion to the db client table.
+ *
+ * @param ctx The Redis context.
+ * @param ray_client_id The ID of the database client that was inserted or
+ *        deleted.
+ * @param client_type The type of client that was inserted or deleted. + * @param aux_address An optional secondary address associated with the + * database client. + * @param is_insertion A boolean that's true if the update was an insertion and + * false if deletion. + * @return True if the publish was successful and false otherwise. + */ +bool PublishDBClientNotification(RedisModuleCtx *ctx, + RedisModuleString *ray_client_id, + RedisModuleString *client_type, + RedisModuleString *aux_address, + bool is_insertion) { + /* Construct strings to publish on the db client channel. */ + RedisModuleString *channel_name = + RedisModule_CreateString(ctx, "db_clients", strlen("db_clients")); + RedisModuleString *client_info; + const char *is_insertion_string = is_insertion ? "1" : "0"; + if (aux_address) { + client_info = + RedisString_Format(ctx, "%S:%S %S %s", ray_client_id, client_type, + aux_address, is_insertion_string); + } else { + client_info = RedisString_Format(ctx, "%S:%S : %s", ray_client_id, + client_type, is_insertion_string); + } + + /* Publish the client info on the db client channel. */ + RedisModuleCallReply *reply; + reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, client_info); + RedisModule_FreeString(ctx, channel_name); + RedisModule_FreeString(ctx, client_info); + return (reply != NULL); +} + /** * Register a client with Redis. This is called from a client with the command: * @@ -110,25 +151,65 @@ int Connect_RedisCommand(RedisModuleCtx *ctx, /* Clean up. */ RedisModule_FreeString(ctx, aux_address_key); RedisModule_CloseKey(db_client_table_key); - - /* Construct strings to publish on the db client channel. 
*/ - RedisModuleString *channel_name = - RedisModule_CreateString(ctx, "db_clients", strlen("db_clients")); - RedisModuleString *client_info; - if (aux_address) { - client_info = RedisString_Format(ctx, "%S:%S %S", ray_client_id, - client_type, aux_address); - } else { - client_info = - RedisString_Format(ctx, "%S:%S :", ray_client_id, client_type); + if (!PublishDBClientNotification(ctx, ray_client_id, client_type, aux_address, + true)) { + return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful"); } - /* Publish the client info on the db client channel. */ - RedisModuleCallReply *reply; - reply = RedisModule_Call(ctx, "PUBLISH", "ss", channel_name, client_info); - RedisModule_FreeString(ctx, channel_name); - RedisModule_FreeString(ctx, client_info); - if (reply == NULL) { + RedisModule_ReplyWithSimpleString(ctx, "OK"); + return REDISMODULE_OK; +} + +/** + * Remove a client from Redis. This is called from a client with the command: + * + * RAY.DISCONNECT + * + * This method also publishes a notification to all subscribers to the + * db_clients channel. The notification consists of a message of the form ":". + * + * @param ray_client_id The db client ID of the client. + * @return OK if the operation was successful. + */ +int Disconnect_RedisCommand(RedisModuleCtx *ctx, + RedisModuleString **argv, + int argc) { + if (argc != 2) { + return RedisModule_WrongArity(ctx); + } + + RedisModuleString *ray_client_id = argv[1]; + + /* Get the client type. */ + RedisModuleKey *db_client_table_key = + OpenPrefixedKey(ctx, DB_CLIENT_PREFIX, ray_client_id, REDISMODULE_WRITE); + if (RedisModule_KeyType(db_client_table_key) == REDISMODULE_KEYTYPE_EMPTY) { + /* Someone else already deleted this client. 
*/
+    RedisModule_CloseKey(db_client_table_key);
+    RedisModule_ReplyWithSimpleString(ctx, "OK");
+    return REDISMODULE_OK;
+  }
+
+  RedisModuleString *client_type = NULL;
+  RedisModuleString *aux_address = NULL;
+  RedisModule_HashGet(db_client_table_key, REDISMODULE_HASH_CFIELDS,
+                      "client_type", &client_type, "aux_address", &aux_address,
+                      NULL);
+
+  /* Remove the client from the client table. */
+  CHECK_ERROR(RedisModule_DeleteKey(db_client_table_key),
+              "Unable to delete db client key.");
+  RedisModule_CloseKey(db_client_table_key);
+
+  /* Publish the deletion notification on the db client channel. */
+  bool published = PublishDBClientNotification(ctx, ray_client_id, client_type,
+                                               aux_address, false);
+  if (aux_address != NULL) { /* Optional field; HashGet may leave it NULL. */
+    RedisModule_FreeString(ctx, aux_address);
+  }
+  RedisModule_FreeString(ctx, client_type);
+  if (!published) {
     return RedisModule_ReplyWithError(ctx, "PUBLISH unsuccessful");
   }
 
@@ -968,7 +1049,12 @@ int RedisModule_OnLoad(RedisModuleCtx *ctx,
   }
 
   if (RedisModule_CreateCommand(ctx, "ray.connect", Connect_RedisCommand,
-                                "write", 0, 0, 0) == REDISMODULE_ERR) {
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
+    return REDISMODULE_ERR;
+  }
+
+  if (RedisModule_CreateCommand(ctx, "ray.disconnect", Disconnect_RedisCommand,
+                                "write pubsub", 0, 0, 0) == REDISMODULE_ERR) {
     return REDISMODULE_ERR;
   }
 
diff --git a/src/common/state/db_client_table.cc b/src/common/state/db_client_table.cc
index 7414baded..dbf75cba8 100644
--- a/src/common/state/db_client_table.cc
+++ b/src/common/state/db_client_table.cc
@@ -1,6 +1,16 @@
 #include "db_client_table.h"
 #include "redis.h"
 
+void db_client_table_remove(DBHandle *db_handle,
+                            DBClientID db_client_id,
+                            RetryInfo *retry,
+                            db_client_table_done_callback done_callback,
+                            void *user_context) {
+  init_table_callback(db_handle, db_client_id, __func__, NULL, retry,
+                      (table_done_callback) done_callback,
+                      redis_db_client_table_remove, user_context);
+}
+
 void db_client_table_subscribe(
     DBHandle *db_handle,
     db_client_table_subscribe_callback
subscribe_callback, diff --git a/src/common/state/db_client_table.h b/src/common/state/db_client_table.h index 033dc8a77..c5fa4cf50 100644 --- a/src/common/state/db_client_table.h +++ b/src/common/state/db_client_table.h @@ -7,6 +7,24 @@ typedef void (*db_client_table_done_callback)(DBClientID db_client_id, void *user_context); +/** + * Remove a client from the db clients table. + * + * @param db_handle Database handle. + * @param db_client_id The database client ID to remove. + * @param retry Information about retrying the request to the database. + * @param done_callback Function to be called when database returns result. + * @param user_context Data that will be passed to done_callback and + * fail_callback. + * @return Void. + * + */ +void db_client_table_remove(DBHandle *db_handle, + DBClientID db_client_id, + RetryInfo *retry, + db_client_table_done_callback done_callback, + void *user_context); + /* * ==== Subscribing to the db client table ==== */ @@ -15,6 +33,7 @@ typedef void (*db_client_table_done_callback)(DBClientID db_client_id, typedef void (*db_client_table_subscribe_callback)(DBClientID db_client_id, const char *client_type, const char *aux_address, + bool is_insertion, void *user_context); /** diff --git a/src/common/state/redis.cc b/src/common/state/redis.cc index 952764e4f..3e3662dd7 100644 --- a/src/common/state/redis.cc +++ b/src/common/state/redis.cc @@ -992,6 +992,36 @@ void redis_task_table_subscribe(TableCallbackData *callback_data) { * ==== db client table callbacks ==== */ +void redis_db_client_table_remove_callback(redisAsyncContext *c, + void *r, + void *privdata) { + REDIS_CALLBACK_HEADER(db, callback_data, r); + redisReply *reply = (redisReply *) r; + + CHECK(reply->type != REDIS_REPLY_ERROR); + CHECK(strcmp(reply->str, "OK") == 0); + + /* Call the done callback if there is one. 
*/ + db_client_table_done_callback done_callback = + (db_client_table_done_callback) callback_data->done_callback; + if (done_callback) { + done_callback(callback_data->id, callback_data->user_context); + } + /* Clean up the timer and callback. */ + destroy_timer_callback(db->loop, callback_data); +} + +void redis_db_client_table_remove(TableCallbackData *callback_data) { + DBHandle *db = callback_data->db_handle; + int status = + redisAsyncCommand(db->context, redis_db_client_table_remove_callback, + (void *) callback_data->timer_id, "RAY.DISCONNECT %b", + callback_data->id.id, sizeof(callback_data->id.id)); + if ((status == REDIS_ERR) || db->context->err) { + LOG_REDIS_DEBUG(db->context, "error in db_client_table_remove"); + } +} + void redis_db_client_table_subscribe_callback(redisAsyncContext *c, void *r, void *privdata) { @@ -1024,19 +1054,30 @@ void redis_db_client_table_subscribe_callback(redisAsyncContext *c, /* We subtract 1 + sizeof(client.id) to compute the length of the * client_type string, and we add 1 to null-terminate the string. */ int client_type_length = payload->len - 1 - sizeof(client.id) + 1; + CHECK(client_type_length > 0); + + /* Parse the client type and auxiliary address from the response. If there is + * only client type, then the update was a delete. 
*/
   char *client_type = (char *) malloc(client_type_length);
   char *aux_address = (char *) malloc(client_type_length);
+  int is_insertion;
   memset(aux_address, 0, client_type_length);
   /* Published message format: */
-  int rv = sscanf(&payload->str[1 + sizeof(client.id)], "%s %s", client_type,
-                  aux_address);
-  CHECKM(rv == 2,
-         "redis_db_client_table_subscribe_callback: expected 2 parsed args, "
+  int rv = sscanf(&payload->str[1 + sizeof(client.id)], "%s %s %d", client_type,
+                  aux_address, &is_insertion);
+  CHECKM(rv == 3,
+         "redis_db_client_table_subscribe_callback: expected 3 parsed args, "
          "Got %d instead.",
          rv);
+  CHECKM(is_insertion == 1 || is_insertion == 0,
+         "redis_db_client_table_subscribe_callback: expected 0 or 1 for "
+         "insertion field, got %d instead.",
+         is_insertion);
+
+  /* Call the subscription callback. */
   if (data->subscribe_callback) {
     data->subscribe_callback(client, client_type, aux_address,
-                             data->subscribe_context);
+                             (bool) is_insertion, data->subscribe_context);
   }
   free(client_type);
   free(aux_address);
diff --git a/src/common/state/redis.h b/src/common/state/redis.h
index 95e1ebf0c..37e075528 100644
--- a/src/common/state/redis.h
+++ b/src/common/state/redis.h
@@ -217,6 +217,15 @@ void redis_task_table_publish_publish_callback(redisAsyncContext *c,
  */
 void redis_task_table_subscribe(TableCallbackData *callback_data);
 
+/**
+ * Remove a client from the db clients table.
+ *
+ * @param callback_data Data structure containing redis connection and timeout
+ *        information.
+ * @return Void.
+ */
+void redis_db_client_table_remove(TableCallbackData *callback_data);
+
 /**
  * Subscribe to updates from the db client table.
* diff --git a/src/common/state/table.cc b/src/common/state/table.cc index 835a0a433..e092ecaa0 100644 --- a/src/common/state/table.cc +++ b/src/common/state/table.cc @@ -81,7 +81,7 @@ int64_t table_timeout_handler(event_loop *loop, CHECK(callback_data->retry.num_retries >= 0 || callback_data->retry.num_retries == -1); - LOG_WARN("retrying operation, retry_count = %d", + LOG_WARN("retrying operation %s, retry_count = %d", callback_data->label, callback_data->retry.num_retries); if (callback_data->retry.num_retries == 0) { diff --git a/src/global_scheduler/global_scheduler.cc b/src/global_scheduler/global_scheduler.cc index ab3b5beb4..575bb06b2 100644 --- a/src/global_scheduler/global_scheduler.cc +++ b/src/global_scheduler/global_scheduler.cc @@ -167,6 +167,87 @@ void process_task_waiting(Task *waiting_task, void *user_context) { } } +void add_local_scheduler(GlobalSchedulerState *state, + DBClientID db_client_id, + const char *aux_address) { + /* Add plasma_manager ip:port -> local_scheduler_db_client_id association to + * state. */ + AuxAddressEntry *plasma_local_scheduler_entry = + (AuxAddressEntry *) calloc(1, sizeof(AuxAddressEntry)); + plasma_local_scheduler_entry->aux_address = strdup(aux_address); + plasma_local_scheduler_entry->local_scheduler_db_client_id = db_client_id; + HASH_ADD_KEYPTR(plasma_local_scheduler_hh, state->plasma_local_scheduler_map, + plasma_local_scheduler_entry->aux_address, + strlen(plasma_local_scheduler_entry->aux_address), + plasma_local_scheduler_entry); + + /* Add local_scheduler_db_client_id -> plasma_manager ip:port association to + * state. */ + HASH_ADD(local_scheduler_plasma_hh, state->local_scheduler_plasma_map, + local_scheduler_db_client_id, + sizeof(plasma_local_scheduler_entry->local_scheduler_db_client_id), + plasma_local_scheduler_entry); + +#if (RAY_COMMON_LOG_LEVEL <= RAY_COMMON_DEBUG) + { + /* Print the local scheduler to plasma association map so far. 
*/
+    char id_string[ID_STRING_SIZE];
+    AuxAddressEntry *entry, *tmp;
+    LOG_DEBUG("Local scheduler to plasma hash map so far:");
+    HASH_ITER(plasma_local_scheduler_hh, state->plasma_local_scheduler_map,
+              entry, tmp) {
+      LOG_DEBUG("%s -> %s", entry->aux_address,
+                ObjectID_to_string(entry->local_scheduler_db_client_id,
+                                   id_string, ID_STRING_SIZE));
+    }
+  }
+#endif
+
+  /* Add new local scheduler to the state. */
+  LocalScheduler local_scheduler;
+  local_scheduler.id = db_client_id;
+  local_scheduler.num_heartbeats_missed = 0;
+  local_scheduler.num_tasks_sent = 0;
+  local_scheduler.num_recent_tasks_sent = 0;
+  local_scheduler.info.task_queue_length = 0;
+  local_scheduler.info.available_workers = 0;
+  memset(local_scheduler.info.dynamic_resources, 0,
+         sizeof(local_scheduler.info.dynamic_resources));
+  memset(local_scheduler.info.static_resources, 0,
+         sizeof(local_scheduler.info.static_resources));
+  utarray_push_back(state->local_schedulers, &local_scheduler);
+  /* Allow the scheduling algorithm to process this event. */
+  handle_new_local_scheduler(state, state->policy_state, db_client_id);
+}
+
+void remove_local_scheduler(GlobalSchedulerState *state, int index) {
+  LocalScheduler *active_worker =
+      (LocalScheduler *) utarray_eltptr(state->local_schedulers, index);
+  DBClientID db_client_id = active_worker->id;
+  utarray_erase(state->local_schedulers, index, 1);
+
+  AuxAddressEntry *entry, *tmp;
+  HASH_ITER(plasma_local_scheduler_hh, state->plasma_local_scheduler_map, entry,
+            tmp) {
+    if (DBClientID_equal(entry->local_scheduler_db_client_id, db_client_id)) {
+      HASH_DELETE(plasma_local_scheduler_hh, state->plasma_local_scheduler_map,
+                  entry);
+      /* The hash entry is shared with the local_scheduler_plasma hashmap and
+       * will be freed there.
*/ + free(entry->aux_address); + } + } + + HASH_FIND(local_scheduler_plasma_hh, state->local_scheduler_plasma_map, + &db_client_id, sizeof(db_client_id), entry); + CHECK(entry != NULL); + HASH_DELETE(local_scheduler_plasma_hh, state->local_scheduler_plasma_map, + entry); + free(entry); + + handle_local_scheduler_removed(state, state->policy_state, db_client_id); +} + /** * Process a notification about a new DB client connecting to Redis. * @param aux_address: an ip:port pair for the plasma manager associated with @@ -175,6 +256,7 @@ void process_task_waiting(Task *waiting_task, void *user_context) { void process_new_db_client(DBClientID db_client_id, const char *client_type, const char *aux_address, + bool is_insertion, void *user_context) { GlobalSchedulerState *state = (GlobalSchedulerState *) user_context; char id_string[ID_STRING_SIZE]; @@ -182,54 +264,22 @@ void process_new_db_client(DBClientID db_client_id, ObjectID_to_string(db_client_id, id_string, ID_STRING_SIZE)); UNUSED(id_string); if (strncmp(client_type, "local_scheduler", strlen("local_scheduler")) == 0) { - /* Add plasma_manager ip:port -> local_scheduler_db_client_id association to - * state. */ - AuxAddressEntry *plasma_local_scheduler_entry = - (AuxAddressEntry *) calloc(1, sizeof(AuxAddressEntry)); - plasma_local_scheduler_entry->aux_address = strdup(aux_address); - plasma_local_scheduler_entry->local_scheduler_db_client_id = db_client_id; - HASH_ADD_KEYPTR(plasma_local_scheduler_hh, - state->plasma_local_scheduler_map, - plasma_local_scheduler_entry->aux_address, - strlen(plasma_local_scheduler_entry->aux_address), - plasma_local_scheduler_entry); - - /* Add local_scheduler_db_client_id -> plasma_manager ip:port association to - * state. 
*/ - HASH_ADD(local_scheduler_plasma_hh, state->local_scheduler_plasma_map, - local_scheduler_db_client_id, - sizeof(plasma_local_scheduler_entry->local_scheduler_db_client_id), - plasma_local_scheduler_entry); - -#if (RAY_COMMON_LOG_LEVEL <= RAY_COMMON_DEBUG) - { - /* Print the local scheduler to plasma association map so far. */ - AuxAddressEntry *entry, *tmp; - LOG_DEBUG("Local scheduler to plasma hash map so far:"); - HASH_ITER(plasma_local_scheduler_hh, state->plasma_local_scheduler_map, - entry, tmp) { - LOG_DEBUG("%s -> %s", entry->aux_address, - ObjectID_to_string(entry->local_scheduler_db_client_id, - id_string, ID_STRING_SIZE)); + if (is_insertion) { + /* This is a notification for an insert. */ + add_local_scheduler(state, db_client_id, aux_address); + } else { + int i = 0; + for (; i < utarray_len(state->local_schedulers); ++i) { + LocalScheduler *active_worker = + (LocalScheduler *) utarray_eltptr(state->local_schedulers, i); + if (DBClientID_equal(active_worker->id, db_client_id)) { + break; + } + } + if (i < utarray_len(state->local_schedulers)) { + remove_local_scheduler(state, i); } } -#endif - - /* Add new local scheduler to the state. */ - LocalScheduler local_scheduler; - local_scheduler.id = db_client_id; - local_scheduler.num_tasks_sent = 0; - local_scheduler.num_recent_tasks_sent = 0; - local_scheduler.info.task_queue_length = 0; - local_scheduler.info.available_workers = 0; - memset(local_scheduler.info.dynamic_resources, 0, - sizeof(local_scheduler.info.dynamic_resources)); - memset(local_scheduler.info.static_resources, 0, - sizeof(local_scheduler.info.static_resources)); - utarray_push_back(state->local_schedulers, &local_scheduler); - - /* Allow the scheduling algorithm to process this event. 
*/
-    handle_new_local_scheduler(state, state->policy_state, db_client_id);
   }
 }
 
@@ -312,6 +362,7 @@ void local_scheduler_table_handler(DBClientID client_id,
   LocalScheduler *local_scheduler_ptr = get_local_scheduler(state, client_id);
   if (local_scheduler_ptr != NULL) {
     /* Reset the number of tasks sent since the last heartbeat. */
+    local_scheduler_ptr->num_heartbeats_missed = 0;
     local_scheduler_ptr->num_recent_tasks_sent = 0;
     local_scheduler_ptr->info = info;
   } else {
@@ -335,6 +386,29 @@ int task_cleanup_handler(event_loop *loop, timer_id id, void *context) {
       free(*pending_task);
     }
   }
+
+  /* Check for local schedulers that have missed a number of heartbeats. If any
+   * local schedulers have died, notify others so that the state can be cleaned
+   * up. */
+  /* TODO(swang): If the local scheduler hasn't actually died, then it should
+   * clean up its state and exit upon receiving this notification. */
+  for (int i = utarray_len(state->local_schedulers) - 1; i >= 0; --i) {
+    LocalScheduler *local_scheduler_ptr =
+        (LocalScheduler *) utarray_eltptr(state->local_schedulers, i);
+    if (local_scheduler_ptr->num_heartbeats_missed >=
+        GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT) {
+      LOG_WARN(
+          "Missed too many heartbeats from local scheduler, marking as dead.");
+      /* Notify others by updating the global state. */
+      db_client_table_remove(state->db, local_scheduler_ptr->id, NULL, NULL,
+                             NULL);
+      /* Remove the scheduler from the local state. The pointer is stale after
+       * the erase, so it must not be touched again. */
+      remove_local_scheduler(state, i);
+    } else {
+      ++local_scheduler_ptr->num_heartbeats_missed;
+    }
+  }
   /* Reset the timer. */
   return GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS;
 }
diff --git a/src/global_scheduler/global_scheduler.h b/src/global_scheduler/global_scheduler.h
index 6bd0cc2bc..6f6e691ea 100644
--- a/src/global_scheduler/global_scheduler.h
+++ b/src/global_scheduler/global_scheduler.h
@@ -11,11 +11,18 @@
 /* The frequency with which the global scheduler checks if there are any tasks
  * that haven't been scheduled yet.
*/ #define GLOBAL_SCHEDULER_TASK_CLEANUP_MILLISECONDS 100 +/* If a local scheduler has not sent a heartbeat in the last + * GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT heartbeat intervals, we will report it + * dead to the db_client table. */ +#define GLOBAL_SCHEDULER_HEARTBEAT_TIMEOUT 100 /** Contains all information that is associated with a local scheduler. */ typedef struct { /** The ID of the local scheduler in Redis. */ DBClientID id; + /** The number of heartbeat intervals that have passed since we last heard + * from this local scheduler. */ + int64_t num_heartbeats_missed; /** The number of tasks sent from the global scheduler to this local * scheduler. */ int64_t num_tasks_sent; diff --git a/src/global_scheduler/global_scheduler_algorithm.cc b/src/global_scheduler/global_scheduler_algorithm.cc index 852a2b61d..be31980d7 100644 --- a/src/global_scheduler/global_scheduler_algorithm.cc +++ b/src/global_scheduler/global_scheduler_algorithm.cc @@ -345,3 +345,9 @@ void handle_new_local_scheduler(GlobalSchedulerState *state, DBClientID db_client_id) { /* Do nothing for now. */ } + +void handle_local_scheduler_removed(GlobalSchedulerState *state, + GlobalSchedulerPolicyState *policy_state, + DBClientID db_client_id) { + /* Do nothing for now. 
*/ +} diff --git a/src/global_scheduler/global_scheduler_algorithm.h b/src/global_scheduler/global_scheduler_algorithm.h index 2468d34ab..253582168 100644 --- a/src/global_scheduler/global_scheduler_algorithm.h +++ b/src/global_scheduler/global_scheduler_algorithm.h @@ -97,4 +97,8 @@ void handle_new_local_scheduler(GlobalSchedulerState *state, GlobalSchedulerPolicyState *policy_state, DBClientID db_client_id); +void handle_local_scheduler_removed(GlobalSchedulerState *state, + GlobalSchedulerPolicyState *policy_state, + DBClientID db_client_id); + #endif /* GLOBAL_SCHEDULER_ALGORITHM_H */ diff --git a/src/local_scheduler/local_scheduler_algorithm.cc b/src/local_scheduler/local_scheduler_algorithm.cc index dd1cf4c39..eef1ca941 100644 --- a/src/local_scheduler/local_scheduler_algorithm.cc +++ b/src/local_scheduler/local_scheduler_algorithm.cc @@ -526,7 +526,6 @@ bool can_run(SchedulingAlgorithmState *algorithm_state, task_spec *task) { return true; } -/* TODO(rkn): This method will need to be changed to call reconstruct. */ /* TODO(swang): This method is not covered by any valgrind tests. */ int fetch_object_timeout_handler(event_loop *loop, timer_id id, void *context) { LocalSchedulerState *state = (LocalSchedulerState *) context; diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index cd550418e..6488320ff 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -75,6 +75,7 @@ void local_scheduler_reconstruct_object(LocalSchedulerConnection *conn, ObjectID object_id) { write_message(conn->conn, RECONSTRUCT_OBJECT, sizeof(object_id), (uint8_t *) &object_id); + /* TODO(swang): Propagate the error. 
*/ } void local_scheduler_log_message(LocalSchedulerConnection *conn) { diff --git a/src/plasma/plasma_store.cc b/src/plasma/plasma_store.cc index f493a97b3..d2e6963d4 100644 --- a/src/plasma/plasma_store.cc +++ b/src/plasma/plasma_store.cc @@ -618,6 +618,7 @@ void send_notifications(event_loop *loop, CHECK(queue != NULL); int num_processed = 0; + bool closed = false; /* Loop over the array of pending notifications and send as many of them as * possible. */ for (int i = 0; i < utarray_len(queue->object_notifications); ++i) { @@ -643,11 +644,24 @@ void send_notifications(event_loop *loop, break; } else { LOG_WARN("Failed to send notification to client on fd %d", client_sock); + if (errno == EPIPE) { + closed = true; + break; + } } num_processed += 1; } /* Remove the sent notifications from the array. */ utarray_erase(queue->object_notifications, 0, num_processed); + + /* Stop sending notifications if the pipe was broken. */ + if (closed) { + close(client_sock); + utarray_free(queue->object_notifications); + HASH_DEL(plasma_state->pending_notifications, queue); + free(queue); + } + /* If we have sent all notifications, remove the fd from the event loop. */ if (utarray_len(queue->object_notifications) == 0) { event_loop_remove_file(loop, client_sock); diff --git a/test/component_failures_test.py b/test/component_failures_test.py index f7d544908..12da5e293 100644 --- a/test/component_failures_test.py +++ b/test/component_failures_test.py @@ -103,5 +103,41 @@ class ComponentFailureTest(unittest.TestCase): def testWorkerFailedMultinode(self): self._testWorkerFailed(4) + def testNodeFailed(self): + @ray.remote + def f(x, j): + time.sleep(0.2) + return x + + # Start 4 local schedulers, each with 8 workers and 8 CPUs. 
+ num_local_schedulers = 4 + num_workers_per_scheduler = 8 + address_info = ray.worker._init(num_workers=num_local_schedulers * num_workers_per_scheduler, + num_local_schedulers=num_local_schedulers, + start_ray_local=True, + num_cpus=[num_workers_per_scheduler] * num_local_schedulers) + + # Submit more tasks than there are workers so that all workers and cores are + # utilized. + object_ids = [f.remote(i, 0) for i in range(num_workers_per_scheduler * num_local_schedulers)] + object_ids += [f.remote(object_id, 1) for object_id in object_ids] + object_ids += [f.remote(object_id, 2) for object_id in object_ids] + + # Kill all nodes except the head node as the tasks execute. + time.sleep(0.1) + local_schedulers = ray.services.all_processes[ray.services.PROCESS_TYPE_LOCAL_SCHEDULER] + for process in local_schedulers[1:]: + process.terminate() + time.sleep(1) + + # Make sure that we can still get the objects after the executing tasks + # died. + results = ray.get(object_ids) + expected_results = 4 * list(range(num_workers_per_scheduler * num_local_schedulers)) + self.assertEqual(results, expected_results) + + ray.worker.cleanup() + + if __name__ == "__main__": unittest.main(verbosity=2)