From a1e4268d37d737637a6615ff24ef266b166c348a Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 7 Aug 2016 13:53:33 -0700 Subject: [PATCH] Catch errors in importing reusable variables and remote functions (#354) * catch errors in importing reusable variables and remote functions * updates --- lib/python/ray/__init__.py | 4 +- lib/python/ray/worker.py | 223 +++++++++++++++++++------------------ protos/ray.proto | 63 ++++++----- protos/types.proto | 21 +++- src/raylib.cc | 90 ++++++++++++--- src/scheduler.cc | 147 ++++++++++++++++-------- src/scheduler.h | 13 ++- src/worker.cc | 113 +++++++++++++------ src/worker.h | 42 ++++--- test/runtest.py | 60 +++++++++- 10 files changed, 512 insertions(+), 264 deletions(-) diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index 34cb3d837..e01643631 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -11,8 +11,8 @@ if hasattr(ctypes, "windll"): import config import serialization -from worker import scheduler_info, visualize_computation_graph, task_info, register_module, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local +from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local from worker import Reusable, reusables -from worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE +from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE from libraylib import ObjectID import internal diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index eb0e05672..fae66f251 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -20,14 +20,6 @@ import services import libnumbuf import libraylib as raylib -# These three constants are used to define the mode that a worker is running in. -# Right now, this is only used for determining how to print information about -# task failures. 
-SCRIPT_MODE = 0 -WORKER_MODE = 1 -PYTHON_MODE = 2 -SILENT_MODE = 3 # This is only used during testing. - class RayTaskError(Exception): """An object used internally to represent a task that threw an exception. @@ -341,7 +333,7 @@ class RayReusables(object): raise Exception("To set a reusable variable, you must pass in a Reusable object") self._names.add(name) self._reusables[name] = reusable - if _mode() in [SCRIPT_MODE, SILENT_MODE]: + if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: _export_reusable_variable(name, reusable) elif _mode() is None: self._cached_reusables.append((name, reusable)) @@ -369,11 +361,13 @@ class Worker(object): handle (worker capsule): A Python object wrapping a C++ Worker object. mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, SILENT_MODE, and WORKER_MODE. - cached_remote_functions (List[str]): A list of serialized remote functions - that were defined before the worker called connect. When the worker + cached_remote_functions (List[Tuple[str, str]]): A list of pairs + representing the remote functions that were defined before the worker + called connect. The first element is the name of the remote function, and + the second element is the serialized remote function. When the worker eventually does call connect, if it is a driver, it will export these functions to the scheduler. If cached_remote_functions is None, that means - that connect has been called. + that connect has been called already. num_failed_tasks (int): The number of tasks that have failed and whose error messages have been displayed to the user. We use this value to know when a failed task hasn't been seen by the user and should be displayed. @@ -507,20 +501,6 @@ class Worker(object): """Make two object IDs refer to the same object.""" raylib.alias_objectids(self.handle, alias_objectid, target_objectid) - def register_function(self, function): - """Register a function with the scheduler. 
- - Notify the scheduler that this worker can execute the function with name - func_name. After this call, the scheduler can send tasks for executing - the function to this worker. - - Args: - function (Callable): The remote function that this worker can execute. - """ - _logger().info("Registering function {}.".format(function.func_name)) - raylib.register_function(self.handle, function.func_name, len(function.return_types)) - self.functions[function.func_name] = function - def submit_task(self, func_name, args): """Submit a remote task to the scheduler. @@ -536,23 +516,8 @@ class Worker(object): """ task_capsule = serialization.serialize_task(self.handle, func_name, args) objectids = raylib.submit_task(self.handle, task_capsule) - if self.mode == SCRIPT_MODE: - self.print_new_failures() return objectids - def print_new_failures(self): - """Print information about tasks.""" - task_data = raylib.task_info(self.handle) - num_tasks_succeeded = task_data["num_succeeded"] - num_tasks_in_progress = len(task_data["running_tasks"]) - num_new_tasks_failed = len(task_data["failed_tasks"]) - self.num_failed_tasks - if num_new_tasks_failed > 0: - # Print the new tasks that have failed. - for task_status in task_data["failed_tasks"][self.num_failed_tasks:]: - print_failed_task(task_status) - print "{}Error: {} new task{} failed.{}".format(colorama.Fore.RED, num_new_tasks_failed, "s" if num_new_tasks_failed > 1 else "", colorama.Fore.RESET) - self.num_failed_tasks = len(task_data["failed_tasks"]) - global_worker = Worker() """Worker: The global Worker object for this worker process. @@ -643,24 +608,7 @@ def task_info(worker=global_worker): check_connected(worker) return raylib.task_info(worker.handle) -def register_module(module, worker=global_worker): - """Register each remote function in a module with the scheduler. - - This registers each remote function in the module with the scheduler, so tasks - with those functions can be scheduled on this worker. 
- - args: - module (module): The module of functions to register. - """ - check_connected(worker) - _logger().info("registering functions in module {}.".format(module.__name__)) - for name in dir(module): - val = getattr(module, name) - if hasattr(val, "is_remote") and val.is_remote: - _logger().info("registering {}.".format(val.func_name)) - worker.register_function(val) - -def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=SCRIPT_MODE): +def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE): """Either connect to an existing Ray cluster or start one and connect to it. This method handles two cases. Either a Ray cluster already exists and we @@ -692,8 +640,8 @@ def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_ # and we connect to them. if (scheduler_address is not None) or (node_ip_address is not None): raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address or a node_ip_address.") - if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: - raise Exception("If start_ray_local=True, then driver_mode must be in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE].") + if driver_mode not in [raylib.SCRIPT_MODE, raylib.PYTHON_MODE, raylib.SILENT_MODE]: + raise Exception("If start_ray_local=True, then driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].") # Use the address 127.0.0.1 in local mode. 
node_ip_address = "127.0.0.1" num_workers = 1 if num_workers is None else num_workers @@ -727,7 +675,7 @@ def cleanup(worker=global_worker): atexit.register(cleanup) -def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=WORKER_MODE): +def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=raylib.WORKER_MODE): """Connect this worker to the scheduler and an object store. Args: @@ -757,13 +705,19 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver _logger().propagate = False # Configure the logging from the worker C++ code. raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker.worker_address, "c++"]) + ".log")) - if mode in [SCRIPT_MODE, SILENT_MODE]: - for function_to_export in worker.cached_remote_functions: - raylib.export_function(worker.handle, function_to_export) + if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + for function_name, function_to_export in worker.cached_remote_functions: + raylib.export_remote_function(worker.handle, function_name, function_to_export) for name, reusable_variable in reusables._cached_reusables: _export_reusable_variable(name, reusable_variable) worker.cached_remote_functions = None reusables._cached_reusables = None + # Start the driver's WorkerService (if this is a driver). This will receive + # GRPC commands from the scheduler to print error messages. We pass in the + # mode below. This tells the WorkerService whether it is operating for a + # driver or a worker and whether it should suppress errors or not. 
+ if is_driver: + raylib.start_worker_service(worker.handle, mode) def disconnect(worker=global_worker): """Disconnect this worker from the scheduler and object store.""" @@ -791,11 +745,9 @@ def get(objectid, worker=global_worker): A Python object """ check_connected(worker) - if worker.mode == PYTHON_MODE: - return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) + if worker.mode == raylib.PYTHON_MODE: + return objectid # In raylib.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid) raylib.request_object(worker.handle, objectid) - if worker.mode == SCRIPT_MODE: - worker.print_new_failures() value = worker.get_object(objectid) if isinstance(value, RayTaskError): # If the result is a RayTaskError, then the task that created this object @@ -813,12 +765,10 @@ def put(value, worker=global_worker): The object ID assigned to this value. """ check_connected(worker) - if worker.mode == PYTHON_MODE: - return value # In PYTHON_MODE, ray.put is the identity operation + if worker.mode == raylib.PYTHON_MODE: + return value # In raylib.PYTHON_MODE, ray.put is the identity operation objectid = raylib.get_objectid(worker.handle) worker.put_object(objectid, value) - if worker.mode == SCRIPT_MODE: - worker.print_new_failures() return objectid def kill_workers(worker=global_worker): @@ -878,27 +828,33 @@ def format_error_message(exception_message): def main_loop(worker=global_worker): """The main loop a worker runs to receive and execute tasks. - This method is an infinite loop. It waits to receive tasks from the scheduler. - When it receives a task, it first deserializes the task. Then it retrieves the - values for any arguments that were passed in as object IDs. Then it - passes the arguments to the actual function. Then it stores the outputs of the - function in the local object store. Then it notifies the scheduler that it - completed the task. 
- - If the process of getting the arguments for execution (which does some type - checking) or the process of executing the task fail, then the main loop will - catch the exception and store RayTaskError objects containing the relevant - error messages in the object store in place of the actual outputs. These - objects are used to propagate the error messages. + This method is an infinite loop. It waits to receive commands from the + scheduler. A command may consist of a task to execute, a remote function to + import, a reusable variable to import, or an order to terminate the worker + process. The worker executes the command, notifies the scheduler of any errors + that occurred while executing the command, and waits for the next command. """ if not raylib.connected(worker.handle): raise Exception("Worker is attempting to enter main_loop but has not been connected yet.") - raylib.start_worker_service(worker.handle) + # We pass in raylib.WORKER_MODE below to indicate that the WorkerService is + # operating for a worker and not a driver. + raylib.start_worker_service(worker.handle, raylib.WORKER_MODE) + def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts - func_name, args, return_objectids = serialization.deserialize_task(worker.handle, task) + """Execute a task assigned to this worker. + + This method deserializes a task from the scheduler, and attempts to execute + the task. If the task succeeds, the outputs are stored in the local object + store. If the task throws an exception, RayTaskError objects are stored in + the object store to represent the failed task (these will be retrieved by + calls to get or by subsequent tasks that use the outputs of this task). + After the task executes, the worker resets any reusable variables that were + accessed by the task. 
+ """ + function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task) try: - arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore - outputs = worker.functions[func_name].executor(arguments) # execute the function + arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore + outputs = worker.functions[function_name].executor(arguments) # execute the function if len(return_objectids) == 1: outputs = (outputs,) except Exception as e: @@ -906,37 +862,82 @@ def main_loop(worker=global_worker): # whether the exception was thrown in the task execution by whether the # variable "arguments" is defined. traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None - failure_object = RayTaskError(func_name, e, traceback_str) + failure_object = RayTaskError(function_name, e, traceback_str) failure_objects = [failure_object for _ in range(len(return_objectids))] store_outputs_in_objstore(return_objectids, failure_objects, worker) - raylib.notify_task_completed(worker.handle, False, str(failure_object)) - _logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name)) + # Notify the scheduler that the task failed. + raylib.notify_failure(worker.handle, function_name, str(failure_object), raylib.FailedTask) + _logger().info("While running function {}, worker threw exception with message: \n\n{}\n".format(function_name, str(failure_object))) else: store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store - raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully - finally: + # Notify the scheduler that the task is done. This happens regardless of + # whether the task succeeded or failed. 
+ raylib.notify_task_completed(worker.handle) + try: # Reinitialize the values of reusable variables that were used in the task # above so that changes made to their state do not affect other tasks. reusables._reinitialize() + except Exception as e: + # The attempt to reinitialize the reusable variables threw an exception. + # We record the traceback and notify the scheduler. + traceback_str = format_error_message(traceback.format_exc()) + raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedReinitializeReusableVariable) + _logger().info("While attempting to reinitialize the reusable variables after running function {}, the worker threw exception with message: \n\n{}\n".format(function_name, traceback_str)) + + def process_remote_function(function_name, serialized_function): + """Import a remote function.""" + try: + (function, arg_types, return_types, module) = pickling.loads(serialized_function) + except: + # If an exception was thrown when the remote function was imported, we + # record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + _logger().info("Failed to import remote function {}. Failed with message: \n\n{}\n".format(function_name, traceback_str)) + # Notify the scheduler that the remote function failed to import. + raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedRemoteFunctionImport) + else: + # TODO(rkn): Why is the below line necessary? + function.__module__ = module + assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in." + worker.functions[function_name] = remote(arg_types, return_types, worker)(function) + _logger().info("Successfully imported remote function {}.".format(function_name)) + # Notify the scheduler that the remote function imported successfully. + # We pass an empty error message string because the import succeeded. 
+ raylib.register_remote_function(worker.handle, function_name, len(return_types)) + + def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str): + """Import a reusable variable.""" + try: + initializer = pickling.loads(initializer_str) + reinitializer = pickling.loads(reinitializer_str) + reusables.__setattr__(reusable_variable_name, Reusable(initializer, reinitializer)) + except: + # If an exception was thrown when the reusable variable was imported, we + # record the traceback and notify the scheduler of the failure. + traceback_str = format_error_message(traceback.format_exc()) + _logger().info("Failed to import reusable variable {}. Failed with message: \n\n{}\n".format(reusable_variable_name, traceback_str)) + # Notify the scheduler that the reusable variable failed to import. + raylib.notify_failure(worker.handle, reusable_variable_name, traceback_str, raylib.FailedReusableVariableImport) + else: + _logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name)) + while True: command, command_args = raylib.wait_for_next_message(worker.handle) try: if command == "die": # We use this as a mechanism to allow the scheduler to kill workers. + _logger().info("Received a 'die' command, and will exit now.") break - elif command == "function": - (function, arg_types, return_types, module) = pickling.loads(command_args) - # TODO(rkn): Why is the below line necessary? 
- function.__module__ = module - worker.register_function(remote(arg_types, return_types, worker)(function)) - elif command == "reusable_variable": - name, initializer_str, reinitializer_str = command_args - initializer = pickling.loads(initializer_str) - reinitializer = pickling.loads(reinitializer_str) - reusables.__setattr__(name, Reusable(initializer, reinitializer)) elif command == "task": process_task(command_args) + elif command == "function": + function_name, serialized_function = command_args + process_remote_function(function_name, serialized_function) + elif command == "reusable_variable": + name, initializer_str, reinitializer_str = command_args + process_reusable_variable(name, initializer_str, reinitializer_str) else: + _logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.") assert False, "This code should be unreachable." finally: # Allow releasing the variables BEFORE we wait for the next message or exit the block @@ -979,7 +980,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker): reusable (Reusable): The reusable object containing code for initializing and reinitializing the variable. """ - if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]: + if _mode(worker) not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: raise Exception("_export_reusable_variable can only be called on a driver.") raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer)) @@ -996,8 +997,8 @@ def remote(arg_types, return_types, worker=global_worker): check_connected() args = list(args) args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments - if _mode() == PYTHON_MODE: - # In PYTHON_MODE, remote calls simply execute the function. We copy the + if _mode() == raylib.PYTHON_MODE: + # In raylib.PYTHON_MODE, remote calls simply execute the function. 
We copy the # arguments to prevent the function call from mutating them and to match # the usual behavior of immutable remote objects. return func(*copy.deepcopy(args)) @@ -1035,7 +1036,7 @@ def remote(arg_types, return_types, worker=global_worker): check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) # Everything ready - export the function - if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]: + if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]: func_name_global_valid = func.__name__ in func.__globals__ func_name_global_value = func.__globals__.get(func.__name__) # Set the function globally to make it refer to itself @@ -1046,10 +1047,10 @@ def remote(arg_types, return_types, worker=global_worker): # Undo our changes if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value else: del func.__globals__[func.__name__] - if worker.mode in [SCRIPT_MODE, SILENT_MODE]: - raylib.export_function(worker.handle, to_export) + if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + raylib.export_remote_function(worker.handle, func_name, to_export) elif worker.mode is None: - worker.cached_remote_functions.append(to_export) + worker.cached_remote_functions.append((func_name, to_export)) return func_invoker return remote_decorator diff --git a/protos/ray.proto b/protos/ray.proto index 6906c5cf2..98b681aad 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -22,8 +22,8 @@ service Scheduler { rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply); // Register an object store with the scheduler rpc RegisterObjStore(RegisterObjStoreRequest) returns (RegisterObjStoreReply); - // Tell the scheduler that a worker can execute a certain function - rpc RegisterFunction(RegisterFunctionRequest) returns (AckReply); + // Tell the scheduler that a worker successfully imported a remote function. 
+ rpc RegisterRemoteFunction(RegisterRemoteFunctionRequest) returns (AckReply); // Asks the scheduler to execute a task, immediately returns an object ID to the result rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply); // Increment the count of the object ID @@ -53,9 +53,11 @@ service Scheduler { // Kills the workers rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply); // Exports function to the workers - rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply); + rpc ExportRemoteFunction(ExportRemoteFunctionRequest) returns (AckReply); // Ship an initializer and reinitializer for a reusable variable to the workers rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply); + // Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable. + rpc NotifyFailure(NotifyFailureRequest) returns (AckReply); } message AckReply { @@ -82,10 +84,14 @@ message RegisterObjStoreReply { uint64 objstoreid = 1; // Object store ID assigned by the scheduler } -message RegisterFunctionRequest { +message RegisterRemoteFunctionRequest { uint64 workerid = 1; // Worker that can execute the function - string fnname = 2; // Name of the function that is registered - uint64 num_return_vals = 3; // Number of return values of the function + string function_name = 2; // Name of the remote function + uint64 num_return_vals = 3; // Number of return values of the function. This is only present if the function was successfully imported. +} + +message NotifyFailure { + Failure failure = 1; // The failure object. 
} message SubmitTaskRequest { @@ -136,11 +142,6 @@ message DecrementRefCountRequest { message ReadyForNewTaskRequest { uint64 workerid = 1; // ID of the worker which executed the task - message PreviousTaskInfo { - bool task_succeeded = 1; // True if the task succeeded, false if it threw an exception - string error_message = 2; // The contents of the exception, if the task threw an exception - } - PreviousTaskInfo previous_task_info = 2; // Information about the previous task, this is only present if there was a previous task } message ChangeCountRequest { @@ -221,10 +222,11 @@ message TaskInfoRequest { } message TaskInfoReply { - repeated TaskStatus failed_task = 1; - repeated TaskStatus running_task = 2; - uint64 num_succeeded = 3; - // TODO(mehrdadn): We'll want to return information from computation_graph since it's important for visualizing tasks that have been completed etc. + repeated TaskStatus failed_task = 1; // The tasks that have failed. + repeated TaskStatus running_task = 2; // The tasks that are currently running. + repeated Failure failed_remote_function_import = 3; // The remote function imports that failed. + repeated Failure failed_reusable_variable_import = 4; // The reusable variable imports that failed. + repeated Failure failed_reinitialize_reusable_variable = 5; // The reusable variable reinitializations that failed. } message KillWorkersRequest { @@ -234,17 +236,18 @@ message KillWorkersReply { bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks } -message ExportFunctionRequest { +message ExportRemoteFunctionRequest { Function function = 1; } -message ExportFunctionReply { -} - message ExportReusableVariableRequest { ReusableVar reusable_variable = 1; // The reusable variable to export. } +message NotifyFailureRequest { + Failure failure = 1; // The failure object. 
+} + // These messages are for getting information about the object store state message ObjStoreInfoRequest { @@ -259,26 +262,21 @@ message ObjStoreInfoReply { // Workers service WorkerService { - rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker - rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker + rpc ExecuteTask(ExecuteTaskRequest) returns (AckReply); // Scheduler calls a function from the worker + rpc ImportRemoteFunction(ImportRemoteFunctionRequest) returns (AckReply); // Scheduler imports a function into the worker rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker - rpc Die(DieRequest) returns (DieReply); // Kills this worker + rpc Die(DieRequest) returns (AckReply); // Kills this worker + rpc PrintErrorMessage(PrintErrorMessageRequest) returns (AckReply); // Causes an error message to be printed. } message ExecuteTaskRequest { Task task = 1; // Contains name of the function to be executed and arguments } -message ExecuteTaskReply { -} - -message ImportFunctionRequest { +message ImportRemoteFunctionRequest { Function function = 1; } -message ImportFunctionReply { -} - message ImportReusableVariableRequest { ReusableVar reusable_variable = 1; // The reusable variable to export. } @@ -286,9 +284,6 @@ message ImportReusableVariableRequest { message DieRequest { } -message DieReply { -} - // This message is used by the worker service to send messages to the worker // that are processed by the worker's main loop. message WorkerMessage { @@ -298,3 +293,7 @@ message WorkerMessage { ReusableVar reusable_variable = 3; // A reusable variable to import on the worker. } } + +message PrintErrorMessageRequest { + Failure failure = 1; // The failure object. 
+} diff --git a/protos/types.proto b/protos/types.proto index 753902996..538066668 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -38,7 +38,8 @@ message PyObj { // Used for shipping remote functions to workers message Function { - bytes implementation = 1; + string name = 1; + bytes implementation = 2; } message ReusableVar { @@ -47,6 +48,24 @@ message ReusableVar { Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable. } +enum FailedType { + FailedTask = 0; + FailedRemoteFunctionImport = 1; + FailedReusableVariableImport = 2; + FailedReinitializeReusableVariable = 3; +} + +// Used to represent exceptions thrown in Python. This will happen when a task +// fails to execute, a remote function fails to be imported, or a reusable +// variable fails to be imported. +message Failure { + FailedType type = 1; // The type of the failure. + uint64 workerid = 2; // The id of the worker on which the failure occurred. + string worker_address = 3; // The address of the worker on which the failure occurred. This contains the same information as the workerid. + string name = 4; // The name of the failed object. + string error_message = 5; // The error message from the failure. 
+} + // Union of possible object types message Obj { String string_data = 1; diff --git a/src/raylib.cc b/src/raylib.cc index f30d0c600..faab42a2a 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -715,7 +715,10 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task())); } else if (function_present) { PyTuple_SetItem(t, 0, PyString_FromString("function")); - PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast(message->function().implementation().size()))); + PyObject* remote_function_data = PyTuple_New(2); + PyTuple_SetItem(remote_function_data, 0, PyString_FromStringAndSize(message->function().name().data(), static_cast(message->function().name().size()))); + PyTuple_SetItem(remote_function_data, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast(message->function().implementation().size()))); + PyTuple_SetItem(t, 1, remote_function_data); } else if (reusable_variable_present) { PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable")); PyObject* reusable_variable = PyTuple_New(3); @@ -734,14 +737,15 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) { Py_RETURN_NONE; } -static PyObject* export_function(PyObject* self, PyObject* args) { +static PyObject* export_remote_function(PyObject* self, PyObject* args) { Worker* worker; + const char* function_name; const char* function; int function_size; - if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) { + if (!PyArg_ParseTuple(args, "O&ss#", &PyObjectToWorker, &worker, &function_name, &function, &function_size)) { return NULL; } - if (worker->export_function(std::string(function, static_cast(function_size)))) { + if (worker->export_remote_function(std::string(function_name), std::string(function, static_cast(function_size)))) { Py_RETURN_TRUE; } else { Py_RETURN_FALSE; @@ 
-795,25 +799,33 @@ static PyObject* submit_task(PyObject* self, PyObject* args) { static PyObject* notify_task_completed(PyObject* self, PyObject* args) { Worker* worker; - PyObject* task_succeeded_obj; - const char* error_message_ptr; - if (!PyArg_ParseTuple(args, "O&Os", &PyObjectToWorker, &worker, &task_succeeded_obj, &error_message_ptr)) { + if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) { return NULL; } - std::string error_message(error_message_ptr); - bool task_succeeded = PyObject_IsTrue(task_succeeded_obj); - worker->notify_task_completed(task_succeeded, error_message); + worker->notify_task_completed(); Py_RETURN_NONE; } -static PyObject* register_function(PyObject* self, PyObject* args) { +static PyObject* register_remote_function(PyObject* self, PyObject* args) { Worker* worker; const char* function_name; int num_return_vals; if (!PyArg_ParseTuple(args, "O&si", &PyObjectToWorker, &worker, &function_name, &num_return_vals)) { return NULL; } - worker->register_function(std::string(function_name), num_return_vals); + worker->register_remote_function(std::string(function_name), num_return_vals); + Py_RETURN_NONE; +} + +static PyObject* notify_failure(PyObject* self, PyObject* args) { + Worker* worker; + const char* name; + const char* error_message; + FailedType type; + if (!PyArg_ParseTuple(args, "O&ssi", &PyObjectToWorker, &worker, &name, &error_message, &type)) { + return NULL; + } + worker->notify_failure(type, std::string(name), std::string(error_message)); Py_RETURN_NONE; } @@ -887,10 +899,11 @@ static PyObject* alias_objectids(PyObject* self, PyObject* args) { static PyObject* start_worker_service(PyObject* self, PyObject* args) { Worker* worker; - if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) { + Mode mode; + if (!PyArg_ParseTuple(args, "O&i", &PyObjectToWorker, &worker, &mode)) { return NULL; } - worker->start_worker_service(); + worker->start_worker_service(mode); Py_RETURN_NONE; } @@ -919,6 +932,15 @@ static 
PyObject* scheduler_info(PyObject* self, PyObject* args) { return dict; } +static PyObject* failure_to_dict(const Failure& failure) { + PyObject* failure_dict = PyDict_New(); + set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("workerid"), PyInt_FromLong(failure.workerid())); + set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("worker_address"), PyString_FromStringAndSize(failure.worker_address().data(), failure.worker_address().size())); + set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("function_name"), PyString_FromStringAndSize(failure.name().data(), failure.name().size())); + set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("error_message"), PyString_FromStringAndSize(failure.error_message().data(), failure.error_message().size())); + return failure_dict; +} + static PyObject* task_info(PyObject* self, PyObject* args) { Worker* worker; if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) { @@ -950,10 +972,27 @@ static PyObject* task_info(PyObject* self, PyObject* args) { PyList_SetItem(running_tasks_list, i, info_dict); } + PyObject* failed_remote_function_imports = PyList_New(reply.failed_remote_function_import_size()); + for (size_t i = 0; i < reply.failed_remote_function_import_size(); ++i) { + PyList_SetItem(failed_remote_function_imports, i, failure_to_dict(reply.failed_remote_function_import(i))); + } + + PyObject* failed_reusable_variable_imports = PyList_New(reply.failed_reusable_variable_import_size()); + for (size_t i = 0; i < reply.failed_reusable_variable_import_size(); ++i) { + PyList_SetItem(failed_reusable_variable_imports, i, failure_to_dict(reply.failed_reusable_variable_import(i))); + } + + PyObject* failed_reinitialize_reusable_variables = PyList_New(reply.failed_reinitialize_reusable_variable_size()); + for (size_t i = 0; i < reply.failed_reinitialize_reusable_variable_size(); ++i) { + PyList_SetItem(failed_reinitialize_reusable_variables, i, 
failure_to_dict(reply.failed_reinitialize_reusable_variable(i))); + } + PyObject* dict = PyDict_New(); set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_tasks"), failed_tasks_list); set_dict_item_and_transfer_ownership(dict, PyString_FromString("running_tasks"), running_tasks_list); - set_dict_item_and_transfer_ownership(dict, PyString_FromString("num_succeeded"), PyInt_FromLong(reply.num_succeeded())); + set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_remote_function_imports"), failed_remote_function_imports); + set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reusable_variable_imports"), failed_reusable_variable_imports); + set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reinitialize_reusable_variables"), failed_reinitialize_reusable_variables); return dict; } @@ -1008,7 +1047,8 @@ static PyMethodDef RayLibMethods[] = { { "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" }, { "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" }, { "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" }, - { "register_function", register_function, METH_VARARGS, "register a function with the scheduler" }, + { "register_remote_function", register_remote_function, METH_VARARGS, "register a function with the scheduler" }, + { "notify_failure", notify_failure, METH_VARARGS, "notify the scheduler of a failure" }, { "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" }, { "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" }, { "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" }, @@ -1019,8 +1059,8 @@ static PyMethodDef RayLibMethods[] = { { "notify_task_completed", notify_task_completed, 
METH_VARARGS, "notify the scheduler that a task has been completed" }, { "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" }, { "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" }, - { "task_info", task_info, METH_VARARGS, "get task statuses" }, - { "export_function", export_function, METH_VARARGS, "export function to workers" }, + { "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" }, + { "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" }, { "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" }, { "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" }, { "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" }, @@ -1046,6 +1086,20 @@ PyMODINIT_FUNC initlibraylib(void) { PyModule_AddObject(m, "ray_error", RayError); PyModule_AddObject(m, "ray_size_error", RaySizeError); import_array(); + + // Export constants used for the worker mode types so they can be accessed + // from Python. The Mode enum is defined in worker.h. + PyModule_AddIntConstant(m, "SCRIPT_MODE", Mode::SCRIPT_MODE); + PyModule_AddIntConstant(m, "WORKER_MODE", Mode::WORKER_MODE); + PyModule_AddIntConstant(m, "PYTHON_MODE", Mode::PYTHON_MODE); + PyModule_AddIntConstant(m, "SILENT_MODE", Mode::SILENT_MODE); + + // Export constants for the failure types so they can be accessed from Python. + // The FailedType enum is defined in types.proto. 
+ PyModule_AddIntConstant(m, "FailedTask", FailedType::FailedTask); + PyModule_AddIntConstant(m, "FailedRemoteFunctionImport", FailedType::FailedRemoteFunctionImport); + PyModule_AddIntConstant(m, "FailedReusableVariableImport", FailedType::FailedReusableVariableImport); + PyModule_AddIntConstant(m, "FailedReinitializeReusableVariable", FailedType::FailedReinitializeReusableVariable); } } diff --git a/src/scheduler.cc b/src/scheduler.cc index 085e94af2..636c198be 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -256,7 +256,13 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo { auto workers = GET(workers_); workerid = workers->size(); - worker_address = node_ip_address + ":" + std::to_string(40000 + workerid); + // Generate a random port number. This is currently a hack to avoid reusing + // port numbers when we run the tests. + std::random_device rd; + std::mt19937 rng(rd()); + std::uniform_int_distribution uni(0, 10000); + int port_number = 40000 + uni(rng); + worker_address = node_ip_address + ":" + std::to_string(port_number); workers->push_back(WorkerHandle()); auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials()); (*workers)[workerid].channel = channel; @@ -279,13 +285,64 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo return Status::OK; } -Status SchedulerService::RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) { - RAY_LOG(RAY_INFO, "register function " << request->fnname() << " from workerid " << request->workerid()); - register_function(request->fnname(), request->workerid(), request->num_return_vals()); +Status SchedulerService::RegisterRemoteFunction(ServerContext* context, const RegisterRemoteFunctionRequest* request, AckReply* reply) { + RAY_LOG(RAY_INFO, "register function " << request->function_name() << " from workerid " << request->workerid()); + register_function(request->function_name(), 
request->workerid(), request->num_return_vals()); schedule(); return Status::OK; } +Status SchedulerService::NotifyFailure(ServerContext* context, const NotifyFailureRequest* request, AckReply* reply) { + const Failure failure = request->failure(); + WorkerId workerid = failure.workerid(); + if (failure.type() == FailedType::FailedTask) { + // A task threw an exception while executing. + TaskStatus failed_task_info; + { + auto workers = GET(workers_); + failed_task_info.set_operationid((*workers)[workerid].current_task); + failed_task_info.set_function_name(failure.name()); + failed_task_info.set_worker_address((*workers)[workerid].worker_address); + failed_task_info.set_error_message(failure.error_message()); + } + GET(failed_tasks_)->push_back(failed_task_info); + RAY_LOG(RAY_INFO, "Error: Task " << failed_task_info.operationid() << " executing function " << failed_task_info.function_name() << " on worker " << workerid << " failed with error message:\n" << failed_task_info.error_message()); + } else if (failure.type() == FailedType::FailedRemoteFunctionImport) { + // An exception was thrown while a remote function was being imported. + GET(failed_remote_function_imports_)->push_back(failure); + RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message()); + } else if (failure.type() == FailedType::FailedReusableVariableImport) { + // An exception was thrown while a reusable variable was being imported. + GET(failed_reusable_variable_imports_)->push_back(failure); + RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message()); + } else if (failure.type() == FailedType::FailedReinitializeReusableVariable) { + // An exception was thrown while a reusable variable was being reinitialized.
+ GET(failed_reinitialize_reusable_variables_)->push_back(failure); + RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message()); + } else { + RAY_CHECK(false, "This code should be unreachable.") + } + // Print the failure on the relevant driver. TODO(rkn): At the moment, this + // prints the failure on all of the drivers. It should probably only print it + // on the driver that caused the problem. + auto workers = GET(workers_); + for (size_t i = 0; i < workers->size(); ++i) { + WorkerHandle* worker = &(*workers)[i]; + // Check if the worker is still connected. + if (worker->worker_stub) { + // Check if this is a driver. + if (worker->current_task == ROOT_OPERATION) { + ClientContext client_context; + PrintErrorMessageRequest print_request; + print_request.mutable_failure()->CopyFrom(request->failure()); + AckReply print_reply; + Status status = worker->worker_stub->PrintErrorMessage(&client_context, print_request, &print_reply); + } + } + } + return Status::OK; +} + Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) { ObjectID objectid = request->objectid(); RAY_LOG(RAY_DEBUG, "object " << objectid << " ready on store " << request->objstoreid()); @@ -306,43 +363,25 @@ Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) { WorkerId workerid = request->workerid(); - OperationId operationid = (*GET(workers_))[workerid].current_task; - RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task"); - RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask."); { - // Check if the worker has been initialized yet, and if not, then give it - // all of the exported functions and all 
of the exported reusable variables. auto workers = GET(workers_); - if (!(*workers)[workerid].initialized) { - // This should only happen once. - // Import all remote functions on the worker. - export_all_functions_to_worker(workerid, workers, GET(exported_functions_)); - // Import all reusable variables on the worker. - export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_)); - // Mark the worker as initialized. - (*workers)[workerid].initialized = true; - } - } - if (request->has_previous_task_info()) { - RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION."); - std::string task_name; - task_name = GET(computation_graph_)->get_task(operationid).name(); - TaskStatus info; + OperationId operationid = (*workers)[workerid].current_task; + RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task"); + RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask."); { - auto workers = GET(workers_); - info.set_operationid(operationid); - info.set_function_name(task_name); - info.set_worker_address((*workers)[workerid].worker_address); - info.set_error_message(request->previous_task_info().error_message()); - (*workers)[workerid].current_task = NO_OPERATION; // clear operation ID + // Check if the worker has been initialized yet, and if not, then give it + // all of the exported functions and all of the exported reusable variables. + if (!(*workers)[workerid].initialized) { + // This should only happen once. + // Import all remote functions on the worker. + export_all_functions_to_worker(workerid, workers, GET(exported_functions_)); + // Import all reusable variables on the worker. + export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_)); + // Mark the worker as initialized. 
+ (*workers)[workerid].initialized = true; + } } - if (!request->previous_task_info().task_succeeded()) { - RAY_LOG(RAY_INFO, "Error: Task " << info.operationid() << " executing function " << info.function_name() << " on worker " << workerid << " failed with error message:\n" << info.error_message()); - GET(failed_tasks_)->push_back(info); - } else { - GET(successful_tasks_)->push_back(info.operationid()); - } - // TODO(rkn): Handle task failure + (*workers)[workerid].current_task = NO_OPERATION; // clear operation ID } GET(avail_workers_)->push_back(workerid); schedule(); @@ -394,14 +433,18 @@ Status SchedulerService::SchedulerInfo(ServerContext* context, const SchedulerIn } Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) { - auto successful_tasks = GET(successful_tasks_); auto failed_tasks = GET(failed_tasks_); + auto failed_remote_function_imports = GET(failed_remote_function_imports_); + auto failed_reusable_variable_imports = GET(failed_reusable_variable_imports_); + auto failed_reinitialize_reusable_variables = GET(failed_reinitialize_reusable_variables_); auto computation_graph = GET(computation_graph_); auto workers = GET(workers_); + // Return information about the failed tasks. for (int i = 0; i < failed_tasks->size(); ++i) { TaskStatus* info = reply->add_failed_task(); *info = (*failed_tasks)[i]; } + // Return information about currently running tasks. for (size_t i = 0; i < workers->size(); ++i) { OperationId operationid = (*workers)[i].current_task; if (operationid != NO_OPERATION && operationid != ROOT_OPERATION) { @@ -412,7 +455,21 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* info->set_worker_address((*workers)[i].worker_address); } } - reply->set_num_succeeded(successful_tasks->size()); + // Return information about failed remote function imports. 
+ for (size_t i = 0; i < failed_remote_function_imports->size(); ++i) { + Failure* failure = reply->add_failed_remote_function_import(); + *failure = (*failed_remote_function_imports)[i]; + } + // Return information about failed reusable variable imports. + for (size_t i = 0; i < failed_reusable_variable_imports->size(); ++i) { + Failure* failure = reply->add_failed_reusable_variable_import(); + *failure = (*failed_reusable_variable_imports)[i]; + } + // Return information about failed reusable variable reinitializations. + for (size_t i = 0; i < failed_reinitialize_reusable_variables->size(); ++i) { + Failure* failure = reply->add_failed_reinitialize_reusable_variable(); + *failure = (*failed_reinitialize_reusable_variables)[i]; + } return Status::OK; } @@ -449,7 +506,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe for (WorkerHandle* idle_worker : idle_workers) { ClientContext client_context; DieRequest die_request; - DieReply die_reply; + AckReply die_reply; // TODO: Fault handling... what if a worker refuses to die? We just assume it dies here. idle_worker->worker_stub->Die(&client_context, die_request, &die_reply); idle_worker->worker_stub.reset(); @@ -464,7 +521,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe return Status::OK; } -Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) { +Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) { auto workers = GET(workers_); auto exported_functions = GET(exported_functions_); // TODO(rkn): Does this do a deep copy? 
@@ -556,7 +613,7 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c const Task& task = computation_graph->get_task(operationid); ClientContext context; ExecuteTaskRequest request; - ExecuteTaskReply reply; + AckReply reply; RAY_LOG(RAY_INFO, "starting to send arguments"); for (size_t i = 0; i < task.arg_size(); ++i) { if (!task.arg(i).has_obj()) { @@ -970,10 +1027,10 @@ void SchedulerService::get_equivalent_objectids(ObjectID objectid, std::vector > &workers, const MySynchronizedPtr > > &exported_functions) { RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid); ClientContext import_context; - ImportFunctionRequest import_request; + ImportRemoteFunctionRequest import_request; import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get()); - ImportFunctionReply import_reply; - (*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply); + AckReply import_reply; + (*workers)[workerid].worker_stub->ImportRemoteFunction(&import_context, import_request, &import_reply); } void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr > &workers, const MySynchronizedPtr > > &exported_reusable_variables) { diff --git a/src/scheduler.h b/src/scheduler.h index 7bec6410a..45372463a 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -65,7 +65,7 @@ public: Status AliasObjectIDs(ServerContext* context, const AliasObjectIDsRequest* request, AckReply* reply) override; Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override; Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override; - Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override; + Status RegisterRemoteFunction(ServerContext* 
context, const RegisterRemoteFunctionRequest* request, AckReply* reply) override; Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override; Status ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) override; Status IncrementRefCount(ServerContext* context, const IncrementRefCountRequest* request, AckReply* reply) override; @@ -74,8 +74,9 @@ public: Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override; Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override; Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override; - Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override; + Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override; Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override; + Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override; #ifdef NDEBUG // If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking. @@ -168,10 +169,14 @@ private: // When we unlock, we subtract back the field offset to restore it to the previous field that was locked. mutable Synchronized > > > lock_orders_; - // List of the IDs of successful tasks - Synchronized > successful_tasks_; // Right now, we only use this information in the TaskInfo call. // List of failed tasks Synchronized > failed_tasks_; + // A list of remote functions import failures. + Synchronized > failed_remote_function_imports_; + // A list of reusable variables import failures. + Synchronized > failed_reusable_variable_imports_; + // A list of reusable variables reinitialization failures. 
+ Synchronized > failed_reinitialize_reusable_variables_; // List of pending get calls. Synchronized > > get_queue_; // The computation graph tracks the operations that have been submitted to the diff --git a/src/worker.cc b/src/worker.cc index 4a4589274..de43ece59 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -9,12 +9,14 @@ extern "C" { static PyObject *RayError; } -inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address) - : worker_address_(worker_address) { +inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address, Mode mode) + : worker_address_(worker_address), + mode_(mode) { RAY_CHECK(send_queue_.connect(worker_address_, false), "error connecting send_queue_"); } -Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) { +Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) { + RAY_CHECK(mode_ == Mode::WORKER_MODE, "ExecuteTask can only be called on workers."); RAY_LOG(RAY_INFO, "invoked task " << request->task().name()); std::unique_ptr message(new WorkerMessage()); message->mutable_task()->CopyFrom(request->task()); @@ -26,7 +28,8 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR return Status::OK; } -Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) { +Status WorkerServiceImpl::ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) { + RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportRemoteFunction can only be called on workers."); std::unique_ptr message(new WorkerMessage()); message->mutable_function()->CopyFrom(request->function()); RAY_LOG(RAY_INFO, "importing function"); @@ -39,6 +42,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun } Status 
WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) { + RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportReusableVariable can only be called on workers."); std::unique_ptr message(new WorkerMessage()); message->mutable_reusable_variable()->CopyFrom(request->reusable_variable()); RAY_LOG(RAY_INFO, "importing reusable variable"); @@ -50,18 +54,46 @@ Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const I return Status::OK; } -Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) { +Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, AckReply* reply) { + RAY_CHECK(mode_ == Mode::WORKER_MODE, "Die can only be called on workers."); WorkerMessage* message_ptr = NULL; RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC"); return Status::OK; } +Status WorkerServiceImpl::PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) { + RAY_CHECK(mode_ != Mode::WORKER_MODE, "PrintErrorMessage can only be called on drivers."); + if (mode_ == Mode::SILENT_MODE) { + // Do not log error messages in this case. This is just used for the tests. + return Status::OK; + } + const Failure failure = request->failure(); + WorkerId workerid = failure.workerid(); + if (failure.type() == FailedType::FailedTask) { + // A task threw an exception while executing. + std::cout << "Error: Worker " << workerid << " failed to execute function " << failure.name() << ". Failed with error message:\n" << failure.error_message() << std::endl; + } else if (failure.type() == FailedType::FailedRemoteFunctionImport) { + // An exception was thrown while a remote function was being imported. 
+ std::cout << "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl; + } else if (failure.type() == FailedType::FailedReusableVariableImport) { + // An exception was thrown while a reusable variable was being imported. + std::cout << "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl; + } else if (failure.type() == FailedType::FailedReinitializeReusableVariable) { + // An exception was thrown while a reusable variable was being reinitialized. + std::cout << "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl; + } else { + RAY_CHECK(false, "This code should be unreachable.") + } + return Status::OK; +} + Worker::Worker(const std::string& scheduler_address) : scheduler_address_(scheduler_address) { auto scheduler_channel = grpc::CreateChannel(scheduler_address, grpc::InsecureChannelCredentials()); scheduler_stub_ = Scheduler::NewStub(scheduler_channel); } + SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) { RAY_CHECK(connected_, "Attempted to perform submit_task but failed."); SubmitTaskReply reply; @@ -312,15 +344,28 @@ void Worker::decrement_reference_count(std::vector &objectids) { } } -void Worker::register_function(const std::string& name, size_t num_return_vals) { +void Worker::register_remote_function(const std::string& name, size_t num_return_vals) { RAY_CHECK(connected_, "Attempted to perform register_function but failed."); ClientContext context; - RegisterFunctionRequest request; - request.set_fnname(name); - request.set_num_return_vals(num_return_vals); + RegisterRemoteFunctionRequest request; request.set_workerid(workerid_); + 
request.set_function_name(name); + request.set_num_return_vals(num_return_vals); AckReply reply; - scheduler_stub_->RegisterFunction(&context, request, &reply); + scheduler_stub_->RegisterRemoteFunction(&context, request, &reply); +} + +void Worker::notify_failure(FailedType type, const std::string& name, const std::string& error_message) { + RAY_CHECK(connected_, "Attempted to perform notify_failure but failed."); + ClientContext context; + NotifyFailureRequest request; + request.mutable_failure()->set_type(type); + request.mutable_failure()->set_workerid(workerid_); + request.mutable_failure()->set_worker_address(worker_address_); + request.mutable_failure()->set_name(name); + request.mutable_failure()->set_error_message(error_message); + AckReply reply; + scheduler_stub_->NotifyFailure(&context, request, &reply); } std::unique_ptr Worker::receive_next_message() { @@ -329,24 +374,19 @@ std::unique_ptr Worker::receive_next_message() { return std::unique_ptr(message_ptr); } -void Worker::notify_task_completed(bool task_succeeded, std::string error_message) { +void Worker::notify_task_completed() { RAY_CHECK(connected_, "Attempted to perform notify_task_completed but failed."); ClientContext context; ReadyForNewTaskRequest request; request.set_workerid(workerid_); - ReadyForNewTaskRequest::PreviousTaskInfo* previous_task_info = request.mutable_previous_task_info(); - previous_task_info->set_task_succeeded(task_succeeded); - previous_task_info->set_error_message(error_message); AckReply reply; scheduler_stub_->ReadyForNewTask(&context, request, &reply); } void Worker::disconnect() { connected_ = false; -} - -bool Worker::connected() { - return connected_; + // TODO(rkn): This probably isn't the right way to clean up the thread. + worker_server_thread_->detach(); } // TODO(rkn): Should we be using pointers or references? And should they be const? 
@@ -360,13 +400,14 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf scheduler_stub_->TaskInfo(&context, request, &reply); } -bool Worker::export_function(const std::string& function) { +bool Worker::export_remote_function(const std::string& function_name, const std::string& function) { RAY_CHECK(connected_, "Attempted to export function but failed."); ClientContext context; - ExportFunctionRequest request; + ExportRemoteFunctionRequest request; + request.mutable_function()->set_name(function_name); request.mutable_function()->set_implementation(function); - ExportFunctionReply reply; - Status status = scheduler_stub_->ExportFunction(&context, request, &reply); + AckReply reply; + Status status = scheduler_stub_->ExportRemoteFunction(&context, request, &reply); return true; } @@ -385,26 +426,32 @@ void Worker::export_reusable_variable(const std::string& name, const std::string // queue. This is because the Python interpreter needs to be single threaded // (in our case running in the main thread), whereas the WorkerService will // run in a separate thread and potentially utilize multiple threads. -void Worker::start_worker_service() { +void Worker::start_worker_service(Mode mode) { const char* service_addr = worker_address_.c_str(); - worker_server_thread_ = std::thread([this, service_addr]() { + // Launch a new thread for running the worker service. We store this as a + // field so that we can clean it up when we disconnect the worker. + worker_server_thread_ = std::unique_ptr(new std::thread([this, service_addr, mode]() { std::string service_address(service_addr); std::string::iterator split_point = split_ip_address(service_address); std::string port; port.assign(split_point, service_address.end()); - WorkerServiceImpl service(service_address); + // Create the worker service. 
+ WorkerServiceImpl service(service_address, mode); ServerBuilder builder; builder.AddListeningPort(std::string("0.0.0.0:") + port, grpc::InsecureServerCredentials()); builder.RegisterService(&service); std::unique_ptr server(builder.BuildAndStart()); RAY_LOG(RAY_INFO, "worker server listening on " << service_address); - - ClientContext context; - ReadyForNewTaskRequest request; - request.set_workerid(workerid_); - AckReply reply; - scheduler_stub_->ReadyForNewTask(&context, request, &reply); - + // If this is part of a worker process (and not a driver process), then tell + // the scheduler that it is ready to start receiving tasks. + if (mode == Mode::WORKER_MODE) { + ClientContext context; + ReadyForNewTaskRequest request; + request.set_workerid(workerid_); + AckReply reply; + scheduler_stub_->ReadyForNewTask(&context, request, &reply); + } + // Wait for work and process work, this does not return. server->Wait(); - }); + })); } diff --git a/src/worker.h b/src/worker.h index ff68abee6..b2920f7ee 100644 --- a/src/worker.h +++ b/src/worker.h @@ -23,16 +23,25 @@ using grpc::Channel; using grpc::ClientContext; using grpc::ClientWriter; +// These four constants are used to define the mode that a worker is running +// in. Right now, this is mostly used for determining how to print information +// about task failures. 
+enum Mode {SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE}; + class WorkerServiceImpl final : public WorkerService::Service { public: - WorkerServiceImpl(const std::string& worker_address); - Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override; - Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override; - Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override; + WorkerServiceImpl(const std::string& worker_address, Mode mode); + Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override; + Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override; + Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override; Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override; + Status PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) override; private: std::string worker_address_; MessageQueue send_queue_; + // The mode that this worker service is running in (one of the Mode values); + // it distinguishes a driver process from a worker process. + Mode mode_; }; class Worker { @@ -71,27 +80,30 @@ class Worker { void increment_reference_count(std::vector &objectid); // decrement the reference count for objectid void decrement_reference_count(std::vector &objectid); - // register function with scheduler - void register_function(const std::string& name, size_t num_return_vals); - // start the worker server which accepts tasks from the scheduler and stores - // it in the message queue, which is read by the Python interpreter - void start_worker_service(); + // Notify the scheduler that a remote function has been imported successfully. 
+ void register_remote_function(const std::string& name, size_t num_return_vals); + // Notify the scheduler that a failure has occurred. + void notify_failure(FailedType type, const std::string& name, const std::string& error_message); + // Start the worker server which accepts commands from the scheduler. For + // workers, these commands are stored in the message queue, which is read by + // the Python interpreter. For drivers, these commands are only for printing + // error messages. + void start_worker_service(Mode mode); // wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down. std::unique_ptr receive_next_message(); // tell the scheduler that we are done with the current task and request the - // next one, if task_succeeded is false, this tells the scheduler that the - // task threw an exception - void notify_task_completed(bool task_succeeded, std::string error_message); + // next one. + void notify_task_completed(); // disconnect the worker void disconnect(); // return connected_ - bool connected(); + bool connected() { return connected_; } // get info about scheduler state void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply); // get task statuses from scheduler void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply); // export function to workers - bool export_function(const std::string& function); + bool export_remote_function(const std::string& function_name, const std::string& function); // export reusable variable to workers void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer); // return the worker address @@ -101,7 +113,7 @@ class Worker { bool connected_; const size_t CHUNK_SIZE = 8 * 1024; std::unique_ptr scheduler_stub_; - std::thread worker_server_thread_; + std::unique_ptr worker_server_thread_; MessageQueue receive_queue_; 
bip::managed_shared_memory segment_; WorkerId workerid_; diff --git a/test/runtest.py b/test/runtest.py index 0eacc5014..e60b5c66b 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -249,14 +249,12 @@ class APITest(unittest.TestCase): task_info = ray.task_info() self.assertEqual(len(task_info["failed_tasks"]), 0) self.assertEqual(len(task_info["running_tasks"]), 0) - self.assertEqual(task_info["num_succeeded"], 1) test_functions.no_op_fail.remote() time.sleep(0.2) task_info = ray.task_info() self.assertEqual(len(task_info["failed_tasks"]), 1) self.assertEqual(len(task_info["running_tasks"]), 0) - self.assertEqual(task_info["num_succeeded"], 1) self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message")) ray.worker.cleanup() @@ -273,7 +271,6 @@ class APITest(unittest.TestCase): task_info = ray.task_info() self.assertEqual(len(task_info["failed_tasks"]), 2) self.assertEqual(len(task_info["running_tasks"]), 0) - self.assertEqual(task_info["num_succeeded"], 0) ray.worker.cleanup() @@ -391,6 +388,63 @@ class TaskStatusTest(unittest.TestCase): ray.worker.cleanup() + def testFailImportingRemoteFunction(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + # This example is somewhat contrived. It should be successfully pickled, and + # then it should throw an exception when it is unpickled. This may depend a + # bit on the specifics of our pickler. + def reducer(*args): + raise Exception("There is a problem here.") + class Foo(object): + def __init__(self): + self.__name__ = "Foo_object" + self.func_doc = "" + self.__globals__ = {} + def __reduce__(self): + return reducer, () + def __call__(self): + return + ray.remote([], [])(Foo()) + time.sleep(0.1) + self.assertTrue("There is a problem here." 
in ray.task_info()["failed_remote_function_imports"][0]["error_message"]) + + ray.worker.cleanup() + + def testFailImportingReusableVariable(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + # This will throw an exception when the reusable variable is imported on the + # workers. + def initializer(): + if ray.worker.global_worker.mode == ray.WORKER_MODE: + raise Exception("The initializer failed.") + return 0 + ray.reusables.foo = ray.Reusable(initializer) + time.sleep(0.1) + # Check that the error message is in the task info. + self.assertTrue("The initializer failed." in ray.task_info()["failed_reusable_variable_imports"][0]["error_message"]) + + ray.worker.cleanup() + + def testFailReinitializingVariable(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + def initializer(): + return 0 + def reinitializer(foo): + raise Exception("The reinitializer failed.") + ray.reusables.foo = ray.Reusable(initializer, reinitializer) + @ray.remote([], []) + def use_foo(): + ray.reusables.foo + use_foo.remote() + time.sleep(0.1) + # Check that the error message is in the task info. + self.assertTrue("The reinitializer failed." in ray.task_info()["failed_reinitialize_reusable_variables"][0]["error_message"]) + + ray.worker.cleanup() + def check_get_deallocated(data): x = ray.put(data) ray.get(x)