diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 5ed698c33..2ebdb1063 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -462,7 +462,7 @@ class Worker(object): not be used. """ # First run the function on the driver. - function() + function(self) # Then run the function on all of the workers. raylib.run_function_on_all_workers(self.handle, pickling.dumps(function)) @@ -649,6 +649,7 @@ def print_error_messages(worker=global_worker): num_failed_remote_function_imports = 0 num_failed_reusable_variable_imports = 0 num_failed_reusable_variable_reinitializations = 0 + num_failed_function_to_runs = 0 while True: try: info = task_info(worker=worker) @@ -668,6 +669,9 @@ def print_error_messages(worker=global_worker): for error in info["failed_reinitialize_reusable_variables"][num_failed_reusable_variable_reinitializations:]: print error["error_message"] num_failed_reusable_variable_reinitializations = len(info["failed_reinitialize_reusable_variables"]) + for error in info["failed_function_to_runs"][num_failed_function_to_runs:]: + print error["error_message"] + num_failed_function_to_runs = len(info["failed_function_to_runs"]) except RayConnectionError: # When the driver is exiting, we set worker.handle to None, which will cause # the check_connected call inside of task_info to raise an exception. We use @@ -731,8 +735,8 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, worker=gl # the same. script_directory = os.path.abspath(os.path.dirname(sys.argv[0])) current_directory = os.path.abspath(os.path.curdir) - worker.run_function_on_all_workers(lambda : sys.path.insert(1, script_directory)) - worker.run_function_on_all_workers(lambda : sys.path.insert(1, current_directory)) + worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, script_directory)) + worker.run_function_on_all_workers(lambda worker : sys.path.insert(1, current_directory)) # Export cached remote functions to the workers. for function_name, function_to_export in worker.cached_remote_functions: raylib.export_remote_function(worker.handle, function_name, function_to_export) @@ -986,14 +990,15 @@ def main_loop(worker=global_worker): # Deserialize the function. function = pickling.loads(serialized_function) # Run the function. - function() + function(worker) except: # If an exception was thrown when the function was run, we record the # traceback and notify the scheduler of the failure. - traceback_str = format_error_message(traceback.format_exc()) + traceback_str = traceback.format_exc() _logger().info("Failed to run function on worker. Failed with message: \n\n{}\n".format(traceback_str)) # Notify the scheduler that running the function failed. - # TODO(rkn): Notify the scheduler. + name = function.__name__ if "function" in locals() and hasattr(function, "__name__") else "" + raylib.notify_failure(worker.handle, name, traceback_str, raylib.FailedFunctionToRun) else: _logger().info("Successfully ran function on worker.") diff --git a/protos/ray.proto b/protos/ray.proto index 916e5d2d6..c9ee0dabb 100644 --- a/protos/ray.proto +++ b/protos/ray.proto @@ -240,6 +240,7 @@ message TaskInfoReply { repeated Failure failed_remote_function_import = 3; // The remote function imports that failed. repeated Failure failed_reusable_variable_import = 4; // The reusable variable imports that failed. repeated Failure failed_reinitialize_reusable_variable = 5; // The reusable variable reinitializations that failed. + repeated Failure failed_function_to_run = 6; // The function to run on all workers that failed. } message KillWorkersRequest { diff --git a/protos/types.proto b/protos/types.proto index ece53bbda..093c92a9d 100644 --- a/protos/types.proto +++ b/protos/types.proto @@ -53,6 +53,7 @@ enum FailedType { FailedRemoteFunctionImport = 1; FailedReusableVariableImport = 2; FailedReinitializeReusableVariable = 3; + FailedFunctionToRun = 4; } // Used to represent exceptions thrown in Python. This will happen when a task diff --git a/src/raylib.cc b/src/raylib.cc index b5adee611..c7cf1b217 100644 --- a/src/raylib.cc +++ b/src/raylib.cc @@ -1044,12 +1044,18 @@ static PyObject* task_info(PyObject* self, PyObject* args) { PyList_SetItem(failed_reinitialize_reusable_variables, i, failure_to_dict(reply.failed_reinitialize_reusable_variable(i))); } + PyObject* failed_function_to_runs = PyList_New(reply.failed_function_to_run_size()); + for (size_t i = 0; i < reply.failed_function_to_run_size(); ++i) { + PyList_SetItem(failed_function_to_runs, i, failure_to_dict(reply.failed_function_to_run(i))); + } + PyObject* dict = PyDict_New(); set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_tasks"), failed_tasks_list); set_dict_item_and_transfer_ownership(dict, PyString_FromString("running_tasks"), running_tasks_list); set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_remote_function_imports"), failed_remote_function_imports); set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reusable_variable_imports"), failed_reusable_variable_imports); set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reinitialize_reusable_variables"), failed_reinitialize_reusable_variables); + set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_function_to_runs"), failed_function_to_runs); return dict; } @@ -1146,6 +1152,7 @@ PyMODINIT_FUNC initlibraylib(void) { PyModule_AddIntConstant(m, "FailedRemoteFunctionImport", FailedType::FailedRemoteFunctionImport); PyModule_AddIntConstant(m, "FailedReusableVariableImport", FailedType::FailedReusableVariableImport); PyModule_AddIntConstant(m, "FailedReinitializeReusableVariable", FailedType::FailedReinitializeReusableVariable); + PyModule_AddIntConstant(m, "FailedFunctionToRun", FailedType::FailedFunctionToRun); } } diff --git a/src/scheduler.cc b/src/scheduler.cc index 3faac2dbe..9e5a3cfd1 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -311,6 +311,10 @@ Status SchedulerService::NotifyFailure(ServerContext* context, const NotifyFailu // An exception was thrown while a reusable variable was being imported. GET(failed_reinitialize_reusable_variables_)->push_back(failure); RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message()); + } else if (failure.type() == FailedType::FailedFunctionToRun) { + // An exception was thrown while a function was being run on all workers. + GET(failed_function_to_runs_)->push_back(failure); + RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to run function " << failure.name() << " on all workers, failed with error message:\n" << failure.error_message()); } else { RAY_CHECK(false, "This code should be unreachable.") } @@ -431,6 +435,7 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* auto failed_remote_function_imports = GET(failed_remote_function_imports_); auto failed_reusable_variable_imports = GET(failed_reusable_variable_imports_); auto failed_reinitialize_reusable_variables = GET(failed_reinitialize_reusable_variables_); + auto failed_function_to_runs = GET(failed_function_to_runs_); auto computation_graph = GET(computation_graph_); auto workers = GET(workers_); // Return information about the failed tasks. @@ -464,6 +469,11 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* Failure* failure = reply->add_failed_reinitialize_reusable_variable(); *failure = (*failed_reinitialize_reusable_variables)[i]; } + // Return information about functions that failed to run on all workers. + for (size_t i = 0; i < failed_function_to_runs->size(); ++i) { + Failure* failure = reply->add_failed_function_to_run(); + *failure = (*failed_function_to_runs)[i]; + } return Status::OK; } diff --git a/src/scheduler.h b/src/scheduler.h index bccc321a4..a9a13e93d 100644 --- a/src/scheduler.h +++ b/src/scheduler.h @@ -188,6 +188,8 @@ private: Synchronized > failed_reusable_variable_imports_; // A list of reusable variables reinitialization failures. Synchronized > failed_reinitialize_reusable_variables_; + // A list of function to run failures. + Synchronized > failed_function_to_runs_; // List of pending get calls. Synchronized > > get_queue_; // The computation graph tracks the operations that have been submitted to the diff --git a/src/worker.cc b/src/worker.cc index eb7c9c6cf..917d18ef3 100644 --- a/src/worker.cc +++ b/src/worker.cc @@ -100,6 +100,9 @@ Status WorkerServiceImpl::PrintErrorMessage(ServerContext* context, const PrintE } else if (failure.type() == FailedType::FailedReinitializeReusableVariable) { // An exception was thrown while a reusable variable was being reinitialized. std::cout << "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl; + } else if (failure.type() == FailedType::FailedFunctionToRun) { + // An exception was thrown while a function was being run on all workers. + std::cout << "Error: Worker " << workerid << " failed to run function " << failure.name() << " on all workers, failed with error message:\n" << failure.error_message() << std::endl; } else { RAY_CHECK(false, "This code should be unreachable.") } diff --git a/test/failure_test.py b/test/failure_test.py index 48017b322..347af0024 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -124,5 +124,23 @@ class TaskStatusTest(unittest.TestCase): ray.worker.cleanup() + def testFailedFunctionToRun(self): + ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE) + + def f(worker): + if ray.worker.global_worker.mode == ray.WORKER_MODE: + raise Exception("Function to run failed.") + ray.worker.global_worker.run_function_on_all_workers(f) + for _ in range(100): # Retry if we need to wait longer. + if len(ray.task_info()["failed_function_to_runs"]) >= 2: + break + time.sleep(0.1) + # Check that the error message is in the task info. + self.assertEqual(len(ray.task_info()["failed_function_to_runs"]), 2) + self.assertTrue("Function to run failed." in ray.task_info()["failed_function_to_runs"][0]["error_message"]) + self.assertTrue("Function to run failed." in ray.task_info()["failed_function_to_runs"][1]["error_message"]) + + ray.worker.cleanup() + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/test/runtest.py b/test/runtest.py index 441648442..d0fc80dd5 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -358,14 +358,14 @@ class APITest(unittest.TestCase): def testRunningFunctionOnAllWorkers(self): ray.init(start_ray_local=True, num_workers=1) - def f(): + def f(worker): sys.path.append("fake_directory") ray.worker.global_worker.run_function_on_all_workers(f) @ray.remote def get_path(): return sys.path self.assertEqual("fake_directory", ray.get(get_path.remote())[-1]) - def f(): + def f(worker): sys.path.pop(-1) ray.worker.global_worker.run_function_on_all_workers(f) self.assertTrue("fake_directory" not in ray.get(get_path.remote()))