Catch errors in importing reusable variables and remote functions (#354)

* catch errors in importing reusable variables and remote functions

* updates
This commit is contained in:
Robert Nishihara
2016-08-07 13:53:33 -07:00
committed by Philipp Moritz
parent a6452aca47
commit a1e4268d37
10 changed files with 512 additions and 264 deletions
+2 -2
View File
@@ -11,8 +11,8 @@ if hasattr(ctypes, "windll"):
import config
import serialization
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
from worker import Reusable, reusables
from worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
from libraylib import ObjectID
import internal
+112 -111
View File
@@ -20,14 +20,6 @@ import services
import libnumbuf
import libraylib as raylib
# These three constants are used to define the mode that a worker is running in.
# Right now, this is only used for determining how to print information about
# task failures.
SCRIPT_MODE = 0
WORKER_MODE = 1
PYTHON_MODE = 2
SILENT_MODE = 3 # This is only used during testing.
class RayTaskError(Exception):
"""An object used internally to represent a task that threw an exception.
@@ -341,7 +333,7 @@ class RayReusables(object):
raise Exception("To set a reusable variable, you must pass in a Reusable object")
self._names.add(name)
self._reusables[name] = reusable
if _mode() in [SCRIPT_MODE, SILENT_MODE]:
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
_export_reusable_variable(name, reusable)
elif _mode() is None:
self._cached_reusables.append((name, reusable))
@@ -369,11 +361,13 @@ class Worker(object):
handle (worker capsule): A Python object wrapping a C++ Worker object.
mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, SILENT_MODE,
and WORKER_MODE.
cached_remote_functions (List[str]): A list of serialized remote functions
that were defined before the worker called connect. When the worker
cached_remote_functions (List[Tuple[str, str]]): A list of pairs
representing the remote functions that were defined before he worker
called connect. The first element is the name of the remote function, and
the second element is the serialized remote function. When the worker
eventually does call connect, if it is a driver, it will export these
functions to the scheduler. If cached_remote_functions is None, that means
that connect has been called.
that connect has been called already.
num_failed_tasks (int): The number of tasks that have failed and whose error
messages have been displayed to the user. We use this value to know when
a failed task hasn't been seen by the user and should be displayed.
@@ -507,20 +501,6 @@ class Worker(object):
"""Make two object IDs refer to the same object."""
raylib.alias_objectids(self.handle, alias_objectid, target_objectid)
def register_function(self, function):
"""Register a function with the scheduler.
Notify the scheduler that this worker can execute the function with name
func_name. After this call, the scheduler can send tasks for executing
the function to this worker.
Args:
function (Callable): The remote function that this worker can execute.
"""
_logger().info("Registering function {}.".format(function.func_name))
raylib.register_function(self.handle, function.func_name, len(function.return_types))
self.functions[function.func_name] = function
def submit_task(self, func_name, args):
"""Submit a remote task to the scheduler.
@@ -536,23 +516,8 @@ class Worker(object):
"""
task_capsule = serialization.serialize_task(self.handle, func_name, args)
objectids = raylib.submit_task(self.handle, task_capsule)
if self.mode == SCRIPT_MODE:
self.print_new_failures()
return objectids
def print_new_failures(self):
"""Print information about tasks."""
task_data = raylib.task_info(self.handle)
num_tasks_succeeded = task_data["num_succeeded"]
num_tasks_in_progress = len(task_data["running_tasks"])
num_new_tasks_failed = len(task_data["failed_tasks"]) - self.num_failed_tasks
if num_new_tasks_failed > 0:
# Print the new tasks that have failed.
for task_status in task_data["failed_tasks"][self.num_failed_tasks:]:
print_failed_task(task_status)
print "{}Error: {} new task{} failed.{}".format(colorama.Fore.RED, num_new_tasks_failed, "s" if num_new_tasks_failed > 1 else "", colorama.Fore.RESET)
self.num_failed_tasks = len(task_data["failed_tasks"])
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
@@ -643,24 +608,7 @@ def task_info(worker=global_worker):
check_connected(worker)
return raylib.task_info(worker.handle)
def register_module(module, worker=global_worker):
"""Register each remote function in a module with the scheduler.
This registers each remote function in the module with the scheduler, so tasks
with those functions can be scheduled on this worker.
args:
module (module): The module of functions to register.
"""
check_connected(worker)
_logger().info("registering functions in module {}.".format(module.__name__))
for name in dir(module):
val = getattr(module, name)
if hasattr(val, "is_remote") and val.is_remote:
_logger().info("registering {}.".format(val.func_name))
worker.register_function(val)
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=SCRIPT_MODE):
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE):
"""Either connect to an existing Ray cluster or start one and connect to it.
This method handles two cases. Either a Ray cluster already exists and we
@@ -692,8 +640,8 @@ def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_
# and we connect to them.
if (scheduler_address is not None) or (node_ip_address is not None):
raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address or a node_ip_address.")
if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]:
raise Exception("If start_ray_local=True, then driver_mode must be in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE].")
if driver_mode not in [raylib.SCRIPT_MODE, raylib.PYTHON_MODE, raylib.SILENT_MODE]:
raise Exception("If start_ray_local=True, then driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")
# Use the address 127.0.0.1 in local mode.
node_ip_address = "127.0.0.1"
num_workers = 1 if num_workers is None else num_workers
@@ -727,7 +675,7 @@ def cleanup(worker=global_worker):
atexit.register(cleanup)
def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=WORKER_MODE):
def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=raylib.WORKER_MODE):
"""Connect this worker to the scheduler and an object store.
Args:
@@ -757,13 +705,19 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver
_logger().propagate = False
# Configure the logging from the worker C++ code.
raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker.worker_address, "c++"]) + ".log"))
if mode in [SCRIPT_MODE, SILENT_MODE]:
for function_to_export in worker.cached_remote_functions:
raylib.export_function(worker.handle, function_to_export)
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
for function_name, function_to_export in worker.cached_remote_functions:
raylib.export_remote_function(worker.handle, function_name, function_to_export)
for name, reusable_variable in reusables._cached_reusables:
_export_reusable_variable(name, reusable_variable)
worker.cached_remote_functions = None
reusables._cached_reusables = None
# Start the driver's WorkerService (if this is a driver). This will receive
# GRPC commands from the scheduler to print error messages. We pass in the
# mode below. This tells the WorkerService whether it is operating for a
# driver or a worker and whether it should surpress errors or not.
if is_driver:
raylib.start_worker_service(worker.handle, mode)
def disconnect(worker=global_worker):
"""Disconnect this worker from the scheduler and object store."""
@@ -791,11 +745,9 @@ def get(objectid, worker=global_worker):
A Python object
"""
check_connected(worker)
if worker.mode == PYTHON_MODE:
return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
if worker.mode == raylib.PYTHON_MODE:
return objectid # In raylib.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
raylib.request_object(worker.handle, objectid)
if worker.mode == SCRIPT_MODE:
worker.print_new_failures()
value = worker.get_object(objectid)
if isinstance(value, RayTaskError):
# If the result is a RayTaskError, then the task that created this object
@@ -813,12 +765,10 @@ def put(value, worker=global_worker):
The object ID assigned to this value.
"""
check_connected(worker)
if worker.mode == PYTHON_MODE:
return value # In PYTHON_MODE, ray.put is the identity operation
if worker.mode == raylib.PYTHON_MODE:
return value # In raylib.PYTHON_MODE, ray.put is the identity operation
objectid = raylib.get_objectid(worker.handle)
worker.put_object(objectid, value)
if worker.mode == SCRIPT_MODE:
worker.print_new_failures()
return objectid
def kill_workers(worker=global_worker):
@@ -878,27 +828,33 @@ def format_error_message(exception_message):
def main_loop(worker=global_worker):
"""The main loop a worker runs to receive and execute tasks.
This method is an infinite loop. It waits to receive tasks from the scheduler.
When it receives a task, it first deserializes the task. Then it retrieves the
values for any arguments that were passed in as object IDs. Then it
passes the arguments to the actual function. Then it stores the outputs of the
function in the local object store. Then it notifies the scheduler that it
completed the task.
If the process of getting the arguments for execution (which does some type
checking) or the process of executing the task fail, then the main loop will
catch the exception and store RayTaskError objects containing the relevant
error messages in the object store in place of the actual outputs. These
objects are used to propagate the error messages.
This method is an infinite loop. It waits to receive commands from the
scheduler. A command may consist of a task to execute, a remote function to
import, a reusable variable to import, or an order to terminate the worker
process. The worker executes the command, notifies the scheduler of any errors
that occurred while executing the command, and waits for the next command.
"""
if not raylib.connected(worker.handle):
raise Exception("Worker is attempting to enter main_loop but has not been connected yet.")
raylib.start_worker_service(worker.handle)
# We pass in raylib.WORKER_MODE below to indicate that the WorkerService is
# operating for a worker and not a driver.
raylib.start_worker_service(worker.handle, raylib.WORKER_MODE)
def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts
func_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
"""Execute a task assigned to this worker.
This method deserializes a task from the scheduler, and attempts to execute
the task. If the task succeeds, the outputs are stored in the local object
store. If the task throws an exception, RayTaskError objects are stored in
the object store to represent the failed task (these will be retrieved by
calls to get or by subsequent tasks that use the outputs of this task).
After the task executes, the worker resets any reusable variables that were
accessed by the task.
"""
function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
try:
arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore
outputs = worker.functions[func_name].executor(arguments) # execute the function
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
outputs = worker.functions[function_name].executor(arguments) # execute the function
if len(return_objectids) == 1:
outputs = (outputs,)
except Exception as e:
@@ -906,37 +862,82 @@ def main_loop(worker=global_worker):
# whether the exception was thrown in the task execution by whether the
# variable "arguments" is defined.
traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None
failure_object = RayTaskError(func_name, e, traceback_str)
failure_object = RayTaskError(function_name, e, traceback_str)
failure_objects = [failure_object for _ in range(len(return_objectids))]
store_outputs_in_objstore(return_objectids, failure_objects, worker)
raylib.notify_task_completed(worker.handle, False, str(failure_object))
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name))
# Notify the scheduler that the task failed.
raylib.notify_failure(worker.handle, function_name, str(failure_object), raylib.FailedTask)
_logger().info("While running function {}, worker threw exception with message: \n\n{}\n".format(function_name, str(failure_object)))
else:
store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
finally:
# Notify the scheduler that the task is done. This happens regardless of
# whether the task succeeded or failed.
raylib.notify_task_completed(worker.handle)
try:
# Reinitialize the values of reusable variables that were used in the task
# above so that changes made to their state do not affect other tasks.
reusables._reinitialize()
except Exception as e:
# The attempt to reinitialize the reusable variables threw an exception.
# We record the traceback and notify the scheduler.
traceback_str = format_error_message(traceback.format_exc())
raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedReinitializeReusableVariable)
_logger().info("While attempting to reinitialize the reusable variables after running function {}, the worker threw exception with message: \n\n{}\n".format(function_name, traceback_str))
def process_remote_function(function_name, serialized_function):
"""Import a remote function."""
try:
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
except:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
_logger().info("Failed to import remote function {}. Failed with message: \n\n{}\n".format(function_name, traceback_str))
# Notify the scheduler that the remote function failed to import.
raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedRemoteFunctionImport)
else:
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
_logger().info("Successfully imported remote function {}.".format(function_name))
# Noify the scheduler that the remote function imported successfully.
# We pass an empty error message string because the import succeeded.
raylib.register_remote_function(worker.handle, function_name, len(return_types))
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
"""Import a reusable variable."""
try:
initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(reusable_variable_name, Reusable(initializer, reinitializer))
except:
# If an exception was thrown when the reusable variable was imported, we
# record the traceback and notify the scheduler of the failure.
traceback_str = format_error_message(traceback.format_exc())
_logger().info("Failed to import reusable variable {}. Failed with message: \n\n{}\n".format(reusable_variable_name, traceback_str))
# Notify the scheduler that the reusable variable failed to import.
raylib.notify_failure(worker.handle, reusable_variable_name, traceback_str, raylib.FailedReusableVariableImport)
else:
_logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))
while True:
command, command_args = raylib.wait_for_next_message(worker.handle)
try:
if command == "die":
# We use this as a mechanism to allow the scheduler to kill workers.
_logger().info("Received a 'die' command, and will exit now.")
break
elif command == "function":
(function, arg_types, return_types, module) = pickling.loads(command_args)
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
worker.register_function(remote(arg_types, return_types, worker)(function))
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
initializer = pickling.loads(initializer_str)
reinitializer = pickling.loads(reinitializer_str)
reusables.__setattr__(name, Reusable(initializer, reinitializer))
elif command == "task":
process_task(command_args)
elif command == "function":
function_name, serialized_function = command_args
process_remote_function(function_name, serialized_function)
elif command == "reusable_variable":
name, initializer_str, reinitializer_str = command_args
process_reusable_variable(name, initializer_str, reinitializer_str)
else:
_logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
assert False, "This code should be unreachable."
finally:
# Allow releasing the variables BEFORE we wait for the next message or exit the block
@@ -979,7 +980,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
reusable (Reusable): The reusable object containing code for initializing
and reinitializing the variable.
"""
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
if _mode(worker) not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raise Exception("_export_reusable_variable can only be called on a driver.")
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
@@ -996,8 +997,8 @@ def remote(arg_types, return_types, worker=global_worker):
check_connected()
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
if _mode() == PYTHON_MODE:
# In PYTHON_MODE, remote calls simply execute the function. We copy the
if _mode() == raylib.PYTHON_MODE:
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
@@ -1035,7 +1036,7 @@ def remote(arg_types, return_types, worker=global_worker):
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
# Everything ready - export the function
if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]:
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
@@ -1046,10 +1047,10 @@ def remote(arg_types, return_types, worker=global_worker):
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
raylib.export_function(worker.handle, to_export)
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raylib.export_remote_function(worker.handle, func_name, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append(to_export)
worker.cached_remote_functions.append((func_name, to_export))
return func_invoker
return remote_decorator
+31 -32
View File
@@ -22,8 +22,8 @@ service Scheduler {
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply);
// Register an object store with the scheduler
rpc RegisterObjStore(RegisterObjStoreRequest) returns (RegisterObjStoreReply);
// Tell the scheduler that a worker can execute a certain function
rpc RegisterFunction(RegisterFunctionRequest) returns (AckReply);
// Tell the scheduler that a worker successfully imported a remote function.
rpc RegisterRemoteFunction(RegisterRemoteFunctionRequest) returns (AckReply);
// Asks the scheduler to execute a task, immediately returns an object ID to the result
rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply);
// Increment the count of the object ID
@@ -53,9 +53,11 @@ service Scheduler {
// Kills the workers
rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
// Exports function to the workers
rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
rpc ExportRemoteFunction(ExportRemoteFunctionRequest) returns (AckReply);
// Ship an initializer and reinitializer for a reusable variable to the workers
rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply);
// Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable.
rpc NotifyFailure(NotifyFailureRequest) returns (AckReply);
}
message AckReply {
@@ -82,10 +84,14 @@ message RegisterObjStoreReply {
uint64 objstoreid = 1; // Object store ID assigned by the scheduler
}
message RegisterFunctionRequest {
message RegisterRemoteFunctionRequest {
uint64 workerid = 1; // Worker that can execute the function
string fnname = 2; // Name of the function that is registered
uint64 num_return_vals = 3; // Number of return values of the function
string function_name = 2; // Name of the remote function
uint64 num_return_vals = 3; // Number of return values of the function. This is only present if the function was successfully imported.
}
message NotifyFailure {
Failure failure = 1; // The failure object.
}
message SubmitTaskRequest {
@@ -136,11 +142,6 @@ message DecrementRefCountRequest {
message ReadyForNewTaskRequest {
uint64 workerid = 1; // ID of the worker which executed the task
message PreviousTaskInfo {
bool task_succeeded = 1; // True if the task succeeded, false if it threw an exception
string error_message = 2; // The contents of the exception, if the task threw an exception
}
PreviousTaskInfo previous_task_info = 2; // Information about the previous task, this is only present if there was a previous task
}
message ChangeCountRequest {
@@ -221,10 +222,11 @@ message TaskInfoRequest {
}
message TaskInfoReply {
repeated TaskStatus failed_task = 1;
repeated TaskStatus running_task = 2;
uint64 num_succeeded = 3;
// TODO(mehrdadn): We'll want to return information from computation_graph since it's important for visualizing tasks that have been completed etc.
repeated TaskStatus failed_task = 1; // The tasks that have failed.
repeated TaskStatus running_task = 2; // The tasks that are currently running.
repeated Failure failed_remote_function_import = 3; // The remote function imports that failed.
repeated Failure failed_reusable_variable_import = 4; // The reusable variable imports that failed.
repeated Failure failed_reinitialize_reusable_variable = 5; // The reusable variable reinitializations that failed.
}
message KillWorkersRequest {
@@ -234,17 +236,18 @@ message KillWorkersReply {
bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
}
message ExportFunctionRequest {
message ExportRemoteFunctionRequest {
Function function = 1;
}
message ExportFunctionReply {
}
message ExportReusableVariableRequest {
ReusableVar reusable_variable = 1; // The reusable variable to export.
}
message NotifyFailureRequest {
Failure failure = 1; // The failure object.
}
// These messages are for getting information about the object store state
message ObjStoreInfoRequest {
@@ -259,26 +262,21 @@ message ObjStoreInfoReply {
// Workers
service WorkerService {
rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
rpc ExecuteTask(ExecuteTaskRequest) returns (AckReply); // Scheduler calls a function from the worker
rpc ImportRemoteFunction(ImportRemoteFunctionRequest) returns (AckReply); // Scheduler imports a function into the worker
rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker
rpc Die(DieRequest) returns (DieReply); // Kills this worker
rpc Die(DieRequest) returns (AckReply); // Kills this worker
rpc PrintErrorMessage(PrintErrorMessageRequest) returns (AckReply); // Causes an error message to be printed.
}
message ExecuteTaskRequest {
Task task = 1; // Contains name of the function to be executed and arguments
}
message ExecuteTaskReply {
}
message ImportFunctionRequest {
message ImportRemoteFunctionRequest {
Function function = 1;
}
message ImportFunctionReply {
}
message ImportReusableVariableRequest {
ReusableVar reusable_variable = 1; // The reusable variable to export.
}
@@ -286,9 +284,6 @@ message ImportReusableVariableRequest {
message DieRequest {
}
message DieReply {
}
// This message is used by the worker service to send messages to the worker
// that are processed by the worker's main loop.
message WorkerMessage {
@@ -298,3 +293,7 @@ message WorkerMessage {
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
}
}
message PrintErrorMessageRequest {
Failure failure = 1; // The failure object.
}
+20 -1
View File
@@ -38,7 +38,8 @@ message PyObj {
// Used for shipping remote functions to workers
message Function {
bytes implementation = 1;
string name = 1;
bytes implementation = 2;
}
message ReusableVar {
@@ -47,6 +48,24 @@ message ReusableVar {
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
}
enum FailedType {
FailedTask = 0;
FailedRemoteFunctionImport = 1;
FailedReusableVariableImport = 2;
FailedReinitializeReusableVariable = 3;
}
// Used to represent exceptions thrown in Python. This will happen when a task
// fails to execute, a remote function fails to be imported, or a reusable
// variable fails to be imported.
message Failure {
FailedType type = 1; // The type of the failure.
uint64 workerid = 2; // The id of the worker on which the failure occurred.
string worker_address = 3; // The address of the worker on which the failure occurred. This contains the same information as the workerid.
string name = 4; // The name of the failed object.
string error_message = 5; // The error message from the failure.
}
// Union of possible object types
message Obj {
String string_data = 1;
+72 -18
View File
@@ -715,7 +715,10 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task()));
} else if (function_present) {
PyTuple_SetItem(t, 0, PyString_FromString("function"));
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
PyObject* remote_function_data = PyTuple_New(2);
PyTuple_SetItem(remote_function_data, 0, PyString_FromStringAndSize(message->function().name().data(), static_cast<ssize_t>(message->function().name().size())));
PyTuple_SetItem(remote_function_data, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
PyTuple_SetItem(t, 1, remote_function_data);
} else if (reusable_variable_present) {
PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable"));
PyObject* reusable_variable = PyTuple_New(3);
@@ -734,14 +737,15 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
Py_RETURN_NONE;
}
static PyObject* export_function(PyObject* self, PyObject* args) {
static PyObject* export_remote_function(PyObject* self, PyObject* args) {
Worker* worker;
const char* function_name;
const char* function;
int function_size;
if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) {
if (!PyArg_ParseTuple(args, "O&ss#", &PyObjectToWorker, &worker, &function_name, &function, &function_size)) {
return NULL;
}
if (worker->export_function(std::string(function, static_cast<size_t>(function_size)))) {
if (worker->export_remote_function(std::string(function_name), std::string(function, static_cast<size_t>(function_size)))) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@@ -795,25 +799,33 @@ static PyObject* submit_task(PyObject* self, PyObject* args) {
static PyObject* notify_task_completed(PyObject* self, PyObject* args) {
Worker* worker;
PyObject* task_succeeded_obj;
const char* error_message_ptr;
if (!PyArg_ParseTuple(args, "O&Os", &PyObjectToWorker, &worker, &task_succeeded_obj, &error_message_ptr)) {
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
return NULL;
}
std::string error_message(error_message_ptr);
bool task_succeeded = PyObject_IsTrue(task_succeeded_obj);
worker->notify_task_completed(task_succeeded, error_message);
worker->notify_task_completed();
Py_RETURN_NONE;
}
static PyObject* register_function(PyObject* self, PyObject* args) {
static PyObject* register_remote_function(PyObject* self, PyObject* args) {
Worker* worker;
const char* function_name;
int num_return_vals;
if (!PyArg_ParseTuple(args, "O&si", &PyObjectToWorker, &worker, &function_name, &num_return_vals)) {
return NULL;
}
worker->register_function(std::string(function_name), num_return_vals);
worker->register_remote_function(std::string(function_name), num_return_vals);
Py_RETURN_NONE;
}
static PyObject* notify_failure(PyObject* self, PyObject* args) {
Worker* worker;
const char* name;
const char* error_message;
FailedType type;
if (!PyArg_ParseTuple(args, "O&ssi", &PyObjectToWorker, &worker, &name, &error_message, &type)) {
return NULL;
}
worker->notify_failure(type, std::string(name), std::string(error_message));
Py_RETURN_NONE;
}
@@ -887,10 +899,11 @@ static PyObject* alias_objectids(PyObject* self, PyObject* args) {
static PyObject* start_worker_service(PyObject* self, PyObject* args) {
Worker* worker;
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
Mode mode;
if (!PyArg_ParseTuple(args, "O&i", &PyObjectToWorker, &worker, &mode)) {
return NULL;
}
worker->start_worker_service();
worker->start_worker_service(mode);
Py_RETURN_NONE;
}
@@ -919,6 +932,15 @@ static PyObject* scheduler_info(PyObject* self, PyObject* args) {
return dict;
}
static PyObject* failure_to_dict(const Failure& failure) {
PyObject* failure_dict = PyDict_New();
set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("workerid"), PyInt_FromLong(failure.workerid()));
set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("worker_address"), PyString_FromStringAndSize(failure.worker_address().data(), failure.worker_address().size()));
set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("function_name"), PyString_FromStringAndSize(failure.name().data(), failure.name().size()));
set_dict_item_and_transfer_ownership(failure_dict, PyString_FromString("error_message"), PyString_FromStringAndSize(failure.error_message().data(), failure.error_message().size()));
return failure_dict;
}
static PyObject* task_info(PyObject* self, PyObject* args) {
Worker* worker;
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
@@ -950,10 +972,27 @@ static PyObject* task_info(PyObject* self, PyObject* args) {
PyList_SetItem(running_tasks_list, i, info_dict);
}
PyObject* failed_remote_function_imports = PyList_New(reply.failed_remote_function_import_size());
for (size_t i = 0; i < reply.failed_remote_function_import_size(); ++i) {
PyList_SetItem(failed_remote_function_imports, i, failure_to_dict(reply.failed_remote_function_import(i)));
}
PyObject* failed_reusable_variable_imports = PyList_New(reply.failed_reusable_variable_import_size());
for (size_t i = 0; i < reply.failed_reusable_variable_import_size(); ++i) {
PyList_SetItem(failed_reusable_variable_imports, i, failure_to_dict(reply.failed_reusable_variable_import(i)));
}
PyObject* failed_reinitialize_reusable_variables = PyList_New(reply.failed_reinitialize_reusable_variable_size());
for (size_t i = 0; i < reply.failed_reinitialize_reusable_variable_size(); ++i) {
PyList_SetItem(failed_reinitialize_reusable_variables, i, failure_to_dict(reply.failed_reinitialize_reusable_variable(i)));
}
PyObject* dict = PyDict_New();
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_tasks"), failed_tasks_list);
set_dict_item_and_transfer_ownership(dict, PyString_FromString("running_tasks"), running_tasks_list);
set_dict_item_and_transfer_ownership(dict, PyString_FromString("num_succeeded"), PyInt_FromLong(reply.num_succeeded()));
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_remote_function_imports"), failed_remote_function_imports);
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reusable_variable_imports"), failed_reusable_variable_imports);
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reinitialize_reusable_variables"), failed_reinitialize_reusable_variables);
return dict;
}
@@ -1008,7 +1047,8 @@ static PyMethodDef RayLibMethods[] = {
{ "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
{ "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" },
{ "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" },
{ "register_function", register_function, METH_VARARGS, "register a function with the scheduler" },
{ "register_remote_function", register_remote_function, METH_VARARGS, "register a function with the scheduler" },
{ "notify_failure", notify_failure, METH_VARARGS, "notify the scheduler of a failure" },
{ "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" },
{ "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" },
{ "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" },
@@ -1019,8 +1059,8 @@ static PyMethodDef RayLibMethods[] = {
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
{ "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" },
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
{ "task_info", task_info, METH_VARARGS, "get task statuses" },
{ "export_function", export_function, METH_VARARGS, "export function to workers" },
{ "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" },
{ "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" },
{ "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" },
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
{ "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" },
@@ -1046,6 +1086,20 @@ PyMODINIT_FUNC initlibraylib(void) {
PyModule_AddObject(m, "ray_error", RayError);
PyModule_AddObject(m, "ray_size_error", RaySizeError);
import_array();
// Export constants used for the worker mode types so they can be accessed
// from Python. The Mode enum is defined in worker.h.
PyModule_AddIntConstant(m, "SCRIPT_MODE", Mode::SCRIPT_MODE);
PyModule_AddIntConstant(m, "WORKER_MODE", Mode::WORKER_MODE);
PyModule_AddIntConstant(m, "PYTHON_MODE", Mode::PYTHON_MODE);
PyModule_AddIntConstant(m, "SILENT_MODE", Mode::SILENT_MODE);
// Export constants for the failure types so they can be accessed from Python.
// The FailedType enum is defined in types.proto.
PyModule_AddIntConstant(m, "FailedTask", FailedType::FailedTask);
PyModule_AddIntConstant(m, "FailedRemoteFunctionImport", FailedType::FailedRemoteFunctionImport);
PyModule_AddIntConstant(m, "FailedReusableVariableImport", FailedType::FailedReusableVariableImport);
PyModule_AddIntConstant(m, "FailedReinitializeReusableVariable", FailedType::FailedReinitializeReusableVariable);
}
}
+102 -45
View File
@@ -256,7 +256,13 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo
{
auto workers = GET(workers_);
workerid = workers->size();
worker_address = node_ip_address + ":" + std::to_string(40000 + workerid);
// Generate a random port number. This is currently a hack to avoid reusing
// port numbers when we run the tests.
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> uni(0, 10000);
int port_number = 40000 + uni(rng);
worker_address = node_ip_address + ":" + std::to_string(port_number);
workers->push_back(WorkerHandle());
auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
(*workers)[workerid].channel = channel;
@@ -279,13 +285,64 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo
return Status::OK;
}
Status SchedulerService::RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) {
RAY_LOG(RAY_INFO, "register function " << request->fnname() << " from workerid " << request->workerid());
register_function(request->fnname(), request->workerid(), request->num_return_vals());
Status SchedulerService::RegisterRemoteFunction(ServerContext* context, const RegisterRemoteFunctionRequest* request, AckReply* reply) {
RAY_LOG(RAY_INFO, "register function " << request->function_name() << " from workerid " << request->workerid());
register_function(request->function_name(), request->workerid(), request->num_return_vals());
schedule();
return Status::OK;
}
Status SchedulerService::NotifyFailure(ServerContext* context, const NotifyFailureRequest* request, AckReply* reply) {
const Failure failure = request->failure();
WorkerId workerid = failure.workerid();
if (failure.type() == FailedType::FailedTask) {
// A task threw an exception while executing.
TaskStatus failed_task_info;
{
auto workers = GET(workers_);
failed_task_info.set_operationid((*workers)[workerid].current_task);
failed_task_info.set_function_name(failure.name());
failed_task_info.set_worker_address((*workers)[workerid].worker_address);
failed_task_info.set_error_message(failure.error_message());
}
GET(failed_tasks_)->push_back(failed_task_info);
RAY_LOG(RAY_INFO, "Error: Task " << failed_task_info.operationid() << " executing function " << failed_task_info.function_name() << " on worker " << workerid << " failed with error message:\n" << failed_task_info.error_message());
} else if (failure.type() == FailedType::FailedRemoteFunctionImport) {
// An exception was thrown while a remote function was being imported.
GET(failed_remote_function_imports_)->push_back(failure);
RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message());
} else if (failure.type() == FailedType::FailedReusableVariableImport) {
// An exception was thrown while a reusable variable was being imported.
GET(failed_reusable_variable_imports_)->push_back(failure);
RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message());
} else if (failure.type() == FailedType::FailedReinitializeReusableVariable) {
// An exception was thrown while a reusable variable was being imported.
GET(failed_reinitialize_reusable_variables_)->push_back(failure);
RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message());
} else {
RAY_CHECK(false, "This code should be unreachable.")
}
// Print the failure on the relevant driver. TODO(rkn): At the moment, this
// prints the failure on all of the drivers. It should probably only print it
// on the driver that caused the problem.
auto workers = GET(workers_);
for (size_t i = 0; i < workers->size(); ++i) {
WorkerHandle* worker = &(*workers)[i];
// Check if the worker is still connected.
if (worker->worker_stub) {
// Check if this is a driver.
if (worker->current_task == ROOT_OPERATION) {
ClientContext client_context;
PrintErrorMessageRequest print_request;
print_request.mutable_failure()->CopyFrom(request->failure());
AckReply print_reply;
Status status = worker->worker_stub->PrintErrorMessage(&client_context, print_request, &print_reply);
}
}
}
return Status::OK;
}
Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) {
ObjectID objectid = request->objectid();
RAY_LOG(RAY_DEBUG, "object " << objectid << " ready on store " << request->objstoreid());
@@ -306,43 +363,25 @@ Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest*
Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) {
WorkerId workerid = request->workerid();
OperationId operationid = (*GET(workers_))[workerid].current_task;
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
{
// Check if the worker has been initialized yet, and if not, then give it
// all of the exported functions and all of the exported reusable variables.
auto workers = GET(workers_);
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Import all remote functions on the worker.
export_all_functions_to_worker(workerid, workers, GET(exported_functions_));
// Import all reusable variables on the worker.
export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_));
// Mark the worker as initialized.
(*workers)[workerid].initialized = true;
}
}
if (request->has_previous_task_info()) {
RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION.");
std::string task_name;
task_name = GET(computation_graph_)->get_task(operationid).name();
TaskStatus info;
OperationId operationid = (*workers)[workerid].current_task;
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
{
auto workers = GET(workers_);
info.set_operationid(operationid);
info.set_function_name(task_name);
info.set_worker_address((*workers)[workerid].worker_address);
info.set_error_message(request->previous_task_info().error_message());
(*workers)[workerid].current_task = NO_OPERATION; // clear operation ID
// Check if the worker has been initialized yet, and if not, then give it
// all of the exported functions and all of the exported reusable variables.
if (!(*workers)[workerid].initialized) {
// This should only happen once.
// Import all remote functions on the worker.
export_all_functions_to_worker(workerid, workers, GET(exported_functions_));
// Import all reusable variables on the worker.
export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_));
// Mark the worker as initialized.
(*workers)[workerid].initialized = true;
}
}
if (!request->previous_task_info().task_succeeded()) {
RAY_LOG(RAY_INFO, "Error: Task " << info.operationid() << " executing function " << info.function_name() << " on worker " << workerid << " failed with error message:\n" << info.error_message());
GET(failed_tasks_)->push_back(info);
} else {
GET(successful_tasks_)->push_back(info.operationid());
}
// TODO(rkn): Handle task failure
(*workers)[workerid].current_task = NO_OPERATION; // clear operation ID
}
GET(avail_workers_)->push_back(workerid);
schedule();
@@ -394,14 +433,18 @@ Status SchedulerService::SchedulerInfo(ServerContext* context, const SchedulerIn
}
Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) {
auto successful_tasks = GET(successful_tasks_);
auto failed_tasks = GET(failed_tasks_);
auto failed_remote_function_imports = GET(failed_remote_function_imports_);
auto failed_reusable_variable_imports = GET(failed_reusable_variable_imports_);
auto failed_reinitialize_reusable_variables = GET(failed_reinitialize_reusable_variables_);
auto computation_graph = GET(computation_graph_);
auto workers = GET(workers_);
// Return information about the failed tasks.
for (int i = 0; i < failed_tasks->size(); ++i) {
TaskStatus* info = reply->add_failed_task();
*info = (*failed_tasks)[i];
}
// Return information about currently running tasks.
for (size_t i = 0; i < workers->size(); ++i) {
OperationId operationid = (*workers)[i].current_task;
if (operationid != NO_OPERATION && operationid != ROOT_OPERATION) {
@@ -412,7 +455,21 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest*
info->set_worker_address((*workers)[i].worker_address);
}
}
reply->set_num_succeeded(successful_tasks->size());
// Return information about failed remote function imports.
for (size_t i = 0; i < failed_remote_function_imports->size(); ++i) {
Failure* failure = reply->add_failed_remote_function_import();
*failure = (*failed_remote_function_imports)[i];
}
// Return information about failed reusable variable imports.
for (size_t i = 0; i < failed_reusable_variable_imports->size(); ++i) {
Failure* failure = reply->add_failed_reusable_variable_import();
*failure = (*failed_reusable_variable_imports)[i];
}
// Return information about failed reusable variable reinitializations.
for (size_t i = 0; i < failed_reinitialize_reusable_variables->size(); ++i) {
Failure* failure = reply->add_failed_reinitialize_reusable_variable();
*failure = (*failed_reinitialize_reusable_variables)[i];
}
return Status::OK;
}
@@ -449,7 +506,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
for (WorkerHandle* idle_worker : idle_workers) {
ClientContext client_context;
DieRequest die_request;
DieReply die_reply;
AckReply die_reply;
// TODO: Fault handling... what if a worker refuses to die? We just assume it dies here.
idle_worker->worker_stub->Die(&client_context, die_request, &die_reply);
idle_worker->worker_stub.reset();
@@ -464,7 +521,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
return Status::OK;
}
Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) {
auto workers = GET(workers_);
auto exported_functions = GET(exported_functions_);
// TODO(rkn): Does this do a deep copy?
@@ -556,7 +613,7 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c
const Task& task = computation_graph->get_task(operationid);
ClientContext context;
ExecuteTaskRequest request;
ExecuteTaskReply reply;
AckReply reply;
RAY_LOG(RAY_INFO, "starting to send arguments");
for (size_t i = 0; i < task.arg_size(); ++i) {
if (!task.arg(i).has_obj()) {
@@ -970,10 +1027,10 @@ void SchedulerService::get_equivalent_objectids(ObjectID objectid, std::vector<O
void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
ClientContext import_context;
ImportFunctionRequest import_request;
ImportRemoteFunctionRequest import_request;
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
ImportFunctionReply import_reply;
(*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
AckReply import_reply;
(*workers)[workerid].worker_stub->ImportRemoteFunction(&import_context, import_request, &import_reply);
}
void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
+9 -4
View File
@@ -65,7 +65,7 @@ public:
Status AliasObjectIDs(ServerContext* context, const AliasObjectIDsRequest* request, AckReply* reply) override;
Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override;
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override;
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override;
Status RegisterRemoteFunction(ServerContext* context, const RegisterRemoteFunctionRequest* request, AckReply* reply) override;
Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override;
Status ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) override;
Status IncrementRefCount(ServerContext* context, const IncrementRefCountRequest* request, AckReply* reply) override;
@@ -74,8 +74,9 @@ public:
Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override;
Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override;
Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override;
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override;
#ifdef NDEBUG
// If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking.
@@ -168,10 +169,14 @@ private:
// When we unlock, we subtract back the field offset to restore it to the previous field that was locked.
mutable Synchronized<std::vector<std::pair<unsigned long long, std::pair<size_t, const char*> > > > lock_orders_;
// List of the IDs of successful tasks
Synchronized<std::vector<OperationId> > successful_tasks_; // Right now, we only use this information in the TaskInfo call.
// List of failed tasks
Synchronized<std::vector<TaskStatus> > failed_tasks_;
// A list of remote functions import failures.
Synchronized<std::vector<Failure> > failed_remote_function_imports_;
// A list of reusable variables import failures.
Synchronized<std::vector<Failure> > failed_reusable_variable_imports_;
// A list of reusable variables reinitialization failures.
Synchronized<std::vector<Failure> > failed_reinitialize_reusable_variables_;
// List of pending get calls.
Synchronized<std::vector<std::pair<WorkerId, ObjectID> > > get_queue_;
// The computation graph tracks the operations that have been submitted to the
+80 -33
View File
@@ -9,12 +9,14 @@ extern "C" {
static PyObject *RayError;
}
inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
: worker_address_(worker_address) {
inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address, Mode mode)
: worker_address_(worker_address),
mode_(mode) {
RAY_CHECK(send_queue_.connect(worker_address_, false), "error connecting send_queue_");
}
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) {
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ExecuteTask can only be called on workers.");
RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->mutable_task()->CopyFrom(request->task());
@@ -26,7 +28,8 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR
return Status::OK;
}
Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
Status WorkerServiceImpl::ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) {
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportRemoteFunction can only be called on workers.");
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->mutable_function()->CopyFrom(request->function());
RAY_LOG(RAY_INFO, "importing function");
@@ -39,6 +42,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun
}
Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) {
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportReusableVariable can only be called on workers.");
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
message->mutable_reusable_variable()->CopyFrom(request->reusable_variable());
RAY_LOG(RAY_INFO, "importing reusable variable");
@@ -50,18 +54,46 @@ Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const I
return Status::OK;
}
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, AckReply* reply) {
RAY_CHECK(mode_ == Mode::WORKER_MODE, "Die can only be called on workers.");
WorkerMessage* message_ptr = NULL;
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
return Status::OK;
}
Status WorkerServiceImpl::PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) {
RAY_CHECK(mode_ != Mode::WORKER_MODE, "PrintErrorMessage can only be called on drivers.");
if (mode_ == Mode::SILENT_MODE) {
// Do not log error messages in this case. This is just used for the tests.
return Status::OK;
}
const Failure failure = request->failure();
WorkerId workerid = failure.workerid();
if (failure.type() == FailedType::FailedTask) {
// A task threw an exception while executing.
std::cout << "Error: Worker " << workerid << " failed to execute function " << failure.name() << ". Failed with error message:\n" << failure.error_message() << std::endl;
} else if (failure.type() == FailedType::FailedRemoteFunctionImport) {
// An exception was thrown while a remote function was being imported.
std::cout << "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
} else if (failure.type() == FailedType::FailedReusableVariableImport) {
// An exception was thrown while a reusable variable was being imported.
std::cout << "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
} else if (failure.type() == FailedType::FailedReinitializeReusableVariable) {
// An exception was thrown while a reusable variable was being reinitialized.
std::cout << "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
} else {
RAY_CHECK(false, "This code should be unreachable.")
}
return Status::OK;
}
Worker::Worker(const std::string& scheduler_address)
: scheduler_address_(scheduler_address) {
auto scheduler_channel = grpc::CreateChannel(scheduler_address, grpc::InsecureChannelCredentials());
scheduler_stub_ = Scheduler::NewStub(scheduler_channel);
}
SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) {
RAY_CHECK(connected_, "Attempted to perform submit_task but failed.");
SubmitTaskReply reply;
@@ -312,15 +344,28 @@ void Worker::decrement_reference_count(std::vector<ObjectID> &objectids) {
}
}
void Worker::register_function(const std::string& name, size_t num_return_vals) {
void Worker::register_remote_function(const std::string& name, size_t num_return_vals) {
RAY_CHECK(connected_, "Attempted to perform register_function but failed.");
ClientContext context;
RegisterFunctionRequest request;
request.set_fnname(name);
request.set_num_return_vals(num_return_vals);
RegisterRemoteFunctionRequest request;
request.set_workerid(workerid_);
request.set_function_name(name);
request.set_num_return_vals(num_return_vals);
AckReply reply;
scheduler_stub_->RegisterFunction(&context, request, &reply);
scheduler_stub_->RegisterRemoteFunction(&context, request, &reply);
}
void Worker::notify_failure(FailedType type, const std::string& name, const std::string& error_message) {
RAY_CHECK(connected_, "Attempted to perform notify_failure but failed.");
ClientContext context;
NotifyFailureRequest request;
request.mutable_failure()->set_type(type);
request.mutable_failure()->set_workerid(workerid_);
request.mutable_failure()->set_worker_address(worker_address_);
request.mutable_failure()->set_name(name);
request.mutable_failure()->set_error_message(error_message);
AckReply reply;
scheduler_stub_->NotifyFailure(&context, request, &reply);
}
std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
@@ -329,24 +374,19 @@ std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
return std::unique_ptr<WorkerMessage>(message_ptr);
}
void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
void Worker::notify_task_completed() {
RAY_CHECK(connected_, "Attempted to perform notify_task_completed but failed.");
ClientContext context;
ReadyForNewTaskRequest request;
request.set_workerid(workerid_);
ReadyForNewTaskRequest::PreviousTaskInfo* previous_task_info = request.mutable_previous_task_info();
previous_task_info->set_task_succeeded(task_succeeded);
previous_task_info->set_error_message(error_message);
AckReply reply;
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
}
void Worker::disconnect() {
connected_ = false;
}
bool Worker::connected() {
return connected_;
// TODO(rkn): This probably isn't the right way to clean up the thread.
worker_server_thread_->detach();
}
// TODO(rkn): Should we be using pointers or references? And should they be const?
@@ -360,13 +400,14 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
scheduler_stub_->TaskInfo(&context, request, &reply);
}
bool Worker::export_function(const std::string& function) {
bool Worker::export_remote_function(const std::string& function_name, const std::string& function) {
RAY_CHECK(connected_, "Attempted to export function but failed.");
ClientContext context;
ExportFunctionRequest request;
ExportRemoteFunctionRequest request;
request.mutable_function()->set_name(function_name);
request.mutable_function()->set_implementation(function);
ExportFunctionReply reply;
Status status = scheduler_stub_->ExportFunction(&context, request, &reply);
AckReply reply;
Status status = scheduler_stub_->ExportRemoteFunction(&context, request, &reply);
return true;
}
@@ -385,26 +426,32 @@ void Worker::export_reusable_variable(const std::string& name, const std::string
// queue. This is because the Python interpreter needs to be single threaded
// (in our case running in the main thread), whereas the WorkerService will
// run in a separate thread and potentially utilize multiple threads.
void Worker::start_worker_service() {
void Worker::start_worker_service(Mode mode) {
const char* service_addr = worker_address_.c_str();
worker_server_thread_ = std::thread([this, service_addr]() {
// Launch a new thread for running the worker service. We store this as a
// field so that we can clean it up when we disconnect the worker.
worker_server_thread_ = std::unique_ptr<std::thread>(new std::thread([this, service_addr, mode]() {
std::string service_address(service_addr);
std::string::iterator split_point = split_ip_address(service_address);
std::string port;
port.assign(split_point, service_address.end());
WorkerServiceImpl service(service_address);
// Create the worker service.
WorkerServiceImpl service(service_address, mode);
ServerBuilder builder;
builder.AddListeningPort(std::string("0.0.0.0:") + port, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<Server> server(builder.BuildAndStart());
RAY_LOG(RAY_INFO, "worker server listening on " << service_address);
ClientContext context;
ReadyForNewTaskRequest request;
request.set_workerid(workerid_);
AckReply reply;
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
// If this is part of a worker process (and not a driver process), then tell
// the scheduler that it is ready to start receiving tasks.
if (mode == Mode::WORKER_MODE) {
ClientContext context;
ReadyForNewTaskRequest request;
request.set_workerid(workerid_);
AckReply reply;
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
}
// Wait for work and process work, this does not return.
server->Wait();
});
}));
}
+27 -15
View File
@@ -23,16 +23,25 @@ using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientWriter;
// These three constants are used to define the mode that a worker is running
// in. Right now, this is mostly used for determining how to print information
// about task failures.
enum Mode {SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE};
class WorkerServiceImpl final : public WorkerService::Service {
public:
WorkerServiceImpl(const std::string& worker_address);
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
WorkerServiceImpl(const std::string& worker_address, Mode mode);
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override;
Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override;
Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override;
Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override;
Status PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) override;
private:
std::string worker_address_;
MessageQueue<WorkerMessage*> send_queue_;
// This is true if the worker service is part of a driver process and false
// if it is part of a worker process.
Mode mode_;
};
class Worker {
@@ -71,27 +80,30 @@ class Worker {
void increment_reference_count(std::vector<ObjectID> &objectid);
// decrement the reference count for objectid
void decrement_reference_count(std::vector<ObjectID> &objectid);
// register function with scheduler
void register_function(const std::string& name, size_t num_return_vals);
// start the worker server which accepts tasks from the scheduler and stores
// it in the message queue, which is read by the Python interpreter
void start_worker_service();
// Notify the scheduler that a remote function has been imported successfully.
void register_remote_function(const std::string& name, size_t num_return_vals);
// Notify the scheduler that a failure has occurred.
void notify_failure(FailedType type, const std::string& name, const std::string& error_message);
// Start the worker server which accepts commands from the scheduler. For
// workers, these commands are stored in the message queue, which is read by
// the Python interpreter. For drivers, these commands are only for printing
// error messages.
void start_worker_service(Mode mode);
// wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down.
std::unique_ptr<WorkerMessage> receive_next_message();
// tell the scheduler that we are done with the current task and request the
// next one, if task_succeeded is false, this tells the scheduler that the
// task threw an exception
void notify_task_completed(bool task_succeeded, std::string error_message);
// next one.
void notify_task_completed();
// disconnect the worker
void disconnect();
// return connected_
bool connected();
bool connected() { return connected_; }
// get info about scheduler state
void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply);
// get task statuses from scheduler
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
// export function to workers
bool export_function(const std::string& function);
bool export_remote_function(const std::string& function_name, const std::string& function);
// export reusable variable to workers
void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer);
// return the worker address
@@ -101,7 +113,7 @@ class Worker {
bool connected_;
const size_t CHUNK_SIZE = 8 * 1024;
std::unique_ptr<Scheduler::Stub> scheduler_stub_;
std::thread worker_server_thread_;
std::unique_ptr<std::thread> worker_server_thread_;
MessageQueue<WorkerMessage*> receive_queue_;
bip::managed_shared_memory segment_;
WorkerId workerid_;
+57 -3
View File
@@ -249,14 +249,12 @@ class APITest(unittest.TestCase):
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 0)
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertEqual(task_info["num_succeeded"], 1)
test_functions.no_op_fail.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 1)
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertEqual(task_info["num_succeeded"], 1)
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
ray.worker.cleanup()
@@ -273,7 +271,6 @@ class APITest(unittest.TestCase):
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 2)
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertEqual(task_info["num_succeeded"], 0)
ray.worker.cleanup()
@@ -391,6 +388,63 @@ class TaskStatusTest(unittest.TestCase):
ray.worker.cleanup()
def testFailImportingRemoteFunction(self):
ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)
# This example is somewhat contrived. It should be successfully pickled, and
# then it should throw an exception when it is unpickled. This may depend a
# bit on the specifics of our pickler.
def reducer(*args):
raise Exception("There is a problem here.")
class Foo(object):
def __init__(self):
self.__name__ = "Foo_object"
self.func_doc = ""
self.__globals__ = {}
def __reduce__(self):
return reducer, ()
def __call__(self):
return
ray.remote([], [])(Foo())
time.sleep(0.1)
self.assertTrue("There is a problem here." in ray.task_info()["failed_remote_function_imports"][0]["error_message"])
ray.worker.cleanup()
def testFailImportingReusableVariable(self):
ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)
# This will throw an exception when the reusable variable is imported on the
# workers.
def initializer():
if ray.worker.global_worker.mode == ray.WORKER_MODE:
raise Exception("The initializer failed.")
return 0
ray.reusables.foo = ray.Reusable(initializer)
time.sleep(0.1)
# Check that the error message is in the task info.
self.assertTrue("The initializer failed." in ray.task_info()["failed_reusable_variable_imports"][0]["error_message"])
ray.worker.cleanup()
def testFailReinitializingVariable(self):
ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)
def initializer():
return 0
def reinitializer(foo):
raise Exception("The reinitializer failed.")
ray.reusables.foo = ray.Reusable(initializer, reinitializer)
@ray.remote([], [])
def use_foo():
ray.reusables.foo
use_foo.remote()
time.sleep(0.1)
# Check that the error message is in the task info.
self.assertTrue("The reinitializer failed." in ray.task_info()["failed_reinitialize_reusable_variables"][0]["error_message"])
ray.worker.cleanup()
def check_get_deallocated(data):
x = ray.put(data)
ray.get(x)