diff --git a/python/ray/global_scheduler/test/test.py b/python/ray/global_scheduler/test/test.py index aeaff66fe..4c223a287 100644 --- a/python/ray/global_scheduler/test/test.py +++ b/python/ray/global_scheduler/test/test.py @@ -23,6 +23,7 @@ PLASMA_STORE_MEMORY = 1000000000 ID_SIZE = 20 NUM_CLUSTER_NODES = 2 +NIL_WORKER_ID = 20 * b"\xff" NIL_ACTOR_ID = 20 * b"\xff" # These constants must match the scheduling state enum in task.h. @@ -101,7 +102,7 @@ class TestGlobalScheduler(unittest.TestCase): static_resource_list=[10, 0]) # Connect to the scheduler. local_scheduler_client = local_scheduler.LocalSchedulerClient( - local_scheduler_name, NIL_ACTOR_ID, False) + local_scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False) self.local_scheduler_clients.append(local_scheduler_client) self.local_scheduler_pids.append(p4) diff --git a/python/ray/local_scheduler/test/test.py b/python/ray/local_scheduler/test/test.py index 89cd0be8c..f34e0673d 100644 --- a/python/ray/local_scheduler/test/test.py +++ b/python/ray/local_scheduler/test/test.py @@ -16,6 +16,7 @@ import ray.plasma as plasma USE_VALGRIND = False ID_SIZE = 20 +NIL_WORKER_ID = 20 * b"\xff" NIL_ACTOR_ID = 20 * b"\xff" @@ -47,7 +48,7 @@ class TestLocalSchedulerClient(unittest.TestCase): plasma_store_name, use_valgrind=USE_VALGRIND) # Connect to the scheduler. self.local_scheduler_client = local_scheduler.LocalSchedulerClient( - scheduler_name, NIL_ACTOR_ID, False) + scheduler_name, NIL_WORKER_ID, NIL_ACTOR_ID, False) def tearDown(self): # Check that the processes are still alive. diff --git a/python/ray/worker.py b/python/ray/worker.py index 7b514a820..be20030bd 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1400,7 +1400,8 @@ def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker, info["manager_socket_name"]) # Create the local scheduler client. worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient( - info["local_scheduler_socket_name"], worker.actor_id, is_worker) + info["local_scheduler_socket_name"], worker.worker_id, worker.actor_id, + is_worker) # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. diff --git a/src/common/task.h b/src/common/task.h index 986717860..1bcc3cc7f 100644 --- a/src/common/task.h +++ b/src/common/task.h @@ -16,6 +16,7 @@ struct TaskBuilder; #define NIL_TASK_ID NIL_ID #define NIL_ACTOR_ID NIL_ID #define NIL_FUNCTION_ID NIL_ID +#define NIL_WORKER_ID NIL_ID typedef UniqueID FunctionID; diff --git a/src/local_scheduler/format/local_scheduler.fbs b/src/local_scheduler/format/local_scheduler.fbs index 4dadf696e..ccbf6ee1b 100644 --- a/src/local_scheduler/format/local_scheduler.fbs +++ b/src/local_scheduler/format/local_scheduler.fbs @@ -10,10 +10,12 @@ enum MessageType:int { // Log a message to the event table. This is sent from a worker to a local // scheduler. EventLogMessage, - // Send an initial connection message to the local scheduler. This is ent from - // a worker to a local scheduler. - // This contains the worker's process ID and actor ID. - RegisterWorkerInfo, + // Send an initial connection message to the local scheduler. This is sent + // from a worker or driver to a local scheduler. + RegisterClientRequest, + // Send a reply confirming the successful registration of a worker or driver. + // This is sent from the local scheduler to a worker or driver. + RegisterClientReply, // Get a new task from the local scheduler. This is sent from a worker to a // local scheduler. GetTask, @@ -31,6 +33,12 @@ enum MessageType:int { PutObject } +// This message is sent from the local scheduler to a worker. +table GetTaskReply { + // A string of bytes representing the task specification. + task_spec: string; +} + table EventLogMessage { key: string; value: string; @@ -38,14 +46,20 @@ table EventLogMessage { // This struct is used to register a new worker with the local scheduler. // It is shipped as part of local_scheduler_connect. -table RegisterWorkerInfo { - // The ID of the actor. - // This is NIL_ACTOR_ID if the worker is not an actor. +table RegisterClientRequest { + // True if the client is a worker and false if the client is a driver. + is_worker: bool; + // The ID of the worker or driver. + client_id: string; + // The ID of the actor. This is NIL_ACTOR_ID if the worker is not an actor. actor_id: string; // The process ID of this worker. worker_pid: long; } +table RegisterClientReply { +} + table ReconstructObject { // Object ID of the object that needs to be reconstructed. object_id: string; diff --git a/src/local_scheduler/local_scheduler.cc b/src/local_scheduler/local_scheduler.cc index ef837a7a8..0149e0508 100644 --- a/src/local_scheduler/local_scheduler.cc +++ b/src/local_scheduler/local_scheduler.cc @@ -458,8 +458,14 @@ void assign_task_to_worker(LocalSchedulerState *state, TaskSpec *spec, int64_t task_spec_size, LocalSchedulerClient *worker) { - if (write_message(worker->sock, MessageType_ExecuteTask, task_spec_size, - (uint8_t *) spec) < 0) { + /* Construct a flatbuffer object to send to the worker. */ + flatbuffers::FlatBufferBuilder fbb; + auto message = + CreateGetTaskReply(fbb, fbb.CreateString((char *) spec, task_spec_size)); + fbb.Finish(message); + + if (write_message(worker->sock, MessageType_ExecuteTask, fbb.GetSize(), + (uint8_t *) fbb.GetBufferPointer()) < 0) { if (errno == EPIPE || errno == EBADF) { /* TODO(rkn): If this happens, the task should be added back to the task * queue. */ @@ -646,6 +652,80 @@ void reconstruct_object(LocalSchedulerState *state, reconstruct_object_lookup_callback, (void *) state); } +void send_client_register_reply(LocalSchedulerState *state, + LocalSchedulerClient *worker) { + flatbuffers::FlatBufferBuilder fbb; + auto message = CreateRegisterClientReply(fbb); + fbb.Finish(message); + + /* Send the message to the client. */ + if (write_message(worker->sock, MessageType_RegisterClientReply, + fbb.GetSize(), fbb.GetBufferPointer()) < 0) { + if (errno == EPIPE || errno == EBADF || errno == ECONNRESET) { + /* Something went wrong, so kill the worker. */ + kill_worker(worker, false); + LOG_WARN( + "Failed to give send register client reply to worker on fd %d. The " + "client may have hung up.", + worker->sock); + } else { + LOG_FATAL("Failed to send register client reply to client on fd %d.", + worker->sock); + } + } +} + +void handle_client_register(LocalSchedulerState *state, + LocalSchedulerClient *worker, + const RegisterClientRequest *message) { + /* Register the worker or driver. */ + if (message->is_worker()) { + /* Update the actor mapping with the actor ID of the worker (if an actor is + * running on the worker). */ + int64_t worker_pid = message->worker_pid(); + ActorID actor_id = from_flatbuf(message->actor_id()); + if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) { + /* Make sure that the local scheduler is aware that it is responsible for + * this actor. */ + actor_map_entry *entry; + HASH_FIND(hh, state->actor_mapping, &actor_id, sizeof(actor_id), entry); + CHECK(entry != NULL); + CHECK(DBClientID_equal(entry->local_scheduler_id, + get_db_client_id(state->db))); + /* Update the worker struct with this actor ID. */ + CHECK(ActorID_equal(worker->actor_id, NIL_ACTOR_ID)); + worker->actor_id = actor_id; + /* Let the scheduling algorithm process the presence of this new + * worker. */ + handle_actor_worker_connect(state, state->algorithm_state, actor_id, + worker); + } + + /* Register worker process id with the scheduler. */ + worker->pid = worker_pid; + /* Determine if this worker is one of our child processes. */ + LOG_DEBUG("PID is %d", worker_pid); + pid_t *child_pid; + int index = 0; + for (child_pid = (pid_t *) utarray_front(state->child_pids); + child_pid != NULL; + child_pid = (pid_t *) utarray_next(state->child_pids, child_pid)) { + if (*child_pid == worker_pid) { + /* If this worker is one of our child processes, mark it as a child so + * that we know that we can wait for the process to exit during + * cleanup. */ + worker->is_child = true; + utarray_erase(state->child_pids, index, 1); + LOG_DEBUG("Found matching child pid %d", worker_pid); + break; + } + ++index; + } + } else { + /* Register the driver. Currently we don't do anything here. */ + } +} + void process_message(event_loop *loop, int client_sock, void *context, @@ -692,50 +772,11 @@ void process_message(event_loop *loop, (uint8_t *) message->value()->data(), message->value()->size()); } } break; - case MessageType_RegisterWorkerInfo: { - /* Update the actor mapping with the actor ID of the worker (if an actor is - * running on the worker). */ - auto message = flatbuffers::GetRoot( + case MessageType_RegisterClientRequest: { + auto message = flatbuffers::GetRoot( utarray_front(state->input_buffer)); - int64_t worker_pid = message->worker_pid(); - ActorID actor_id = from_flatbuf(message->actor_id()); - if (!ActorID_equal(actor_id, NIL_ACTOR_ID)) { - /* Make sure that the local scheduler is aware that it is responsible for - * this actor. */ - actor_map_entry *entry; - HASH_FIND(hh, state->actor_mapping, &actor_id, sizeof(actor_id), entry); - CHECK(entry != NULL); - CHECK(DBClientID_equal(entry->local_scheduler_id, - get_db_client_id(state->db))); - /* Update the worker struct with this actor ID. */ - CHECK(ActorID_equal(worker->actor_id, NIL_ACTOR_ID)); - worker->actor_id = actor_id; - /* Let the scheduling algorithm process the presence of this new - * worker. */ - handle_actor_worker_connect(state, state->algorithm_state, actor_id, - worker); - } - - /* Register worker process id with the scheduler. */ - worker->pid = worker_pid; - /* Determine if this worker is one of our child processes. */ - LOG_DEBUG("PID is %d", worker_pid); - pid_t *child_pid; - int index = 0; - for (child_pid = (pid_t *) utarray_front(state->child_pids); - child_pid != NULL; - child_pid = (pid_t *) utarray_next(state->child_pids, child_pid)) { - if (*child_pid == worker_pid) { - /* If this worker is one of our child processes, mark it as a child so - * that we know that we can wait for the process to exit during - * cleanup. */ - worker->is_child = true; - utarray_erase(state->child_pids, index, 1); - LOG_DEBUG("Found matching child pid %d", worker_pid); - break; - } - ++index; - } + handle_client_register(state, worker, message); + send_client_register_reply(state, worker); } break; case MessageType_GetTask: { /* If this worker reports a completed task: account for resources. */ diff --git a/src/local_scheduler/local_scheduler_client.cc b/src/local_scheduler/local_scheduler_client.cc index 61d3e5a47..ca8955bfd 100644 --- a/src/local_scheduler/local_scheduler_client.cc +++ b/src/local_scheduler/local_scheduler_client.cc @@ -9,25 +9,40 @@ LocalSchedulerConnection *LocalSchedulerConnection_init( const char *local_scheduler_socket, + UniqueID client_id, ActorID actor_id, bool is_worker) { LocalSchedulerConnection *result = (LocalSchedulerConnection *) malloc(sizeof(LocalSchedulerConnection)); result->conn = connect_ipc_sock_retry(local_scheduler_socket, -1, -1); - if (is_worker) { - /* If we are a worker, register with the local scheduler. - * NOTE(swang): If the local scheduler exits and we are registered as a - * worker, we will get killed. */ - flatbuffers::FlatBufferBuilder fbb; - auto message = - CreateRegisterWorkerInfo(fbb, to_flatbuf(fbb, actor_id), getpid()); - fbb.Finish(message); - /* Register the process ID with the local scheduler. */ - int success = write_message(result->conn, MessageType_RegisterWorkerInfo, - fbb.GetSize(), fbb.GetBufferPointer()); - CHECKM(success == 0, "Unable to register worker with local scheduler"); + /* Register with the local scheduler. + * NOTE(swang): If the local scheduler exits and we are registered as a + * worker, we will get killed. */ + flatbuffers::FlatBufferBuilder fbb; + auto message = + CreateRegisterClientRequest(fbb, is_worker, to_flatbuf(fbb, client_id), + to_flatbuf(fbb, actor_id), getpid()); + fbb.Finish(message); + /* Register the process ID with the local scheduler. */ + int success = write_message(result->conn, MessageType_RegisterClientRequest, + fbb.GetSize(), fbb.GetBufferPointer()); + CHECKM(success == 0, "Unable to register worker with local scheduler"); + + /* Wait for a confirmation from the local scheduler. */ + int64_t type; + int64_t reply_size; + uint8_t *reply; + read_message(result->conn, &type, &reply_size, &reply); + if (type == DISCONNECT_CLIENT) { + LOG_WARN("Exiting because local scheduler closed connection."); + exit(1); } + CHECK(type == MessageType_RegisterClientReply); + + /* Parse the reply object. We currently don't do anything with it. */ + auto reply_message = flatbuffers::GetRoot(reply); + free(reply); return result; } @@ -62,13 +77,27 @@ TaskSpec *local_scheduler_get_task(LocalSchedulerConnection *conn, int64_t *task_size) { write_message(conn->conn, MessageType_GetTask, 0, NULL); int64_t type; + int64_t message_size; uint8_t *message; /* Receive a task from the local scheduler. This will block until the local * scheduler gives this client a task. */ - read_message(conn->conn, &type, task_size, &message); + read_message(conn->conn, &type, &message_size, &message); + if (type == DISCONNECT_CLIENT) { + LOG_WARN("Exiting because local scheduler closed connection."); + exit(1); + } CHECK(type == MessageType_ExecuteTask); - TaskSpec *task = (TaskSpec *) message; - return task; + + /* Parse the flatbuffer object. */ + auto reply_message = flatbuffers::GetRoot(message); + /* Create a copy of the task spec so we can free the reply. */ + *task_size = reply_message->task_spec()->size(); + TaskSpec *spec = (TaskSpec *) malloc(*task_size); + memcpy(spec, reply_message->task_spec()->data(), *task_size); + /* Free the original message from the local scheduler. */ + free(message); + /* Return the copy of the task spec and pass ownership to the caller. */ + return spec; } void local_scheduler_task_done(LocalSchedulerConnection *conn) { diff --git a/src/local_scheduler/local_scheduler_client.h b/src/local_scheduler/local_scheduler_client.h index b05341712..8e63a70f7 100644 --- a/src/local_scheduler/local_scheduler_client.h +++ b/src/local_scheduler/local_scheduler_client.h @@ -23,6 +23,7 @@ typedef struct { */ LocalSchedulerConnection *LocalSchedulerConnection_init( const char *local_scheduler_socket, + UniqueID worker_id, ActorID actor_id, bool is_worker); diff --git a/src/local_scheduler/local_scheduler_extension.cc b/src/local_scheduler/local_scheduler_extension.cc index 46d133571..e9f399ea4 100644 --- a/src/local_scheduler/local_scheduler_extension.cc +++ b/src/local_scheduler/local_scheduler_extension.cc @@ -17,21 +17,25 @@ static int PyLocalSchedulerClient_init(PyLocalSchedulerClient *self, PyObject *args, PyObject *kwds) { char *socket_name; + UniqueID client_id; ActorID actor_id; PyObject *is_worker; - if (!PyArg_ParseTuple(args, "sO&O", &socket_name, PyStringToUniqueID, - &actor_id, &is_worker)) { + self->local_scheduler_connection = NULL; + if (!PyArg_ParseTuple(args, "sO&O&O", &socket_name, PyStringToUniqueID, + &client_id, PyStringToUniqueID, &actor_id, + &is_worker)) { return -1; } /* Connect to the local scheduler. */ self->local_scheduler_connection = LocalSchedulerConnection_init( - socket_name, actor_id, (bool) PyObject_IsTrue(is_worker)); + socket_name, client_id, actor_id, (bool) PyObject_IsTrue(is_worker)); return 0; } static void PyLocalSchedulerClient_dealloc(PyLocalSchedulerClient *self) { - LocalSchedulerConnection_free( - ((PyLocalSchedulerClient *) self)->local_scheduler_connection); + if (self->local_scheduler_connection != NULL) { + LocalSchedulerConnection_free(self->local_scheduler_connection); + } Py_TYPE(self)->tp_free((PyObject *) self); } diff --git a/src/local_scheduler/test/local_scheduler_tests.cc b/src/local_scheduler/test/local_scheduler_tests.cc index 225e9740c..a7247a2c1 100644 --- a/src/local_scheduler/test/local_scheduler_tests.cc +++ b/src/local_scheduler/test/local_scheduler_tests.cc @@ -7,6 +7,8 @@ #include #include +#include + #include "common.h" #include "test/test_common.h" #include "test/example_task.h" @@ -54,6 +56,23 @@ typedef struct { LocalSchedulerConnection **conns; } LocalSchedulerMock; +/** + * Register clients of the local scheduler. This function is started in a + * separate thread so enable a blocking call to register the clients. + */ +static void register_clients(int num_mock_workers, LocalSchedulerMock *mock) { + for (int i = 0; i < num_mock_workers; ++i) { + new_client_connection(mock->loop, mock->local_scheduler_fd, + (void *) mock->local_scheduler_state, 0); + + LocalSchedulerClient **worker = (LocalSchedulerClient **) utarray_eltptr( + mock->local_scheduler_state->workers, i); + + process_message(mock->local_scheduler_state->loop, (*worker)->sock, *worker, + 0); + } +} + LocalSchedulerMock *LocalSchedulerMock_init(int num_workers, int num_mock_workers) { const char *node_ip_address = "127.0.0.1"; @@ -101,13 +120,18 @@ LocalSchedulerMock *LocalSchedulerMock_init(int num_workers, mock->num_local_scheduler_conns = num_mock_workers; mock->conns = (LocalSchedulerConnection **) malloc( sizeof(LocalSchedulerConnection *) * num_mock_workers); + + std::thread background_thread = + std::thread(register_clients, num_mock_workers, mock); + for (int i = 0; i < num_mock_workers; ++i) { mock->conns[i] = LocalSchedulerConnection_init( - utstring_body(local_scheduler_socket_name), NIL_ACTOR_ID, true); - new_client_connection(mock->loop, mock->local_scheduler_fd, - (void *) mock->local_scheduler_state, 0); + utstring_body(local_scheduler_socket_name), NIL_WORKER_ID, NIL_ACTOR_ID, + true); } + background_thread.join(); + utstring_free(worker_command); utstring_free(plasma_manager_socket_name); utstring_free(local_scheduler_socket_name);