mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 22:42:17 +08:00
Catch errors in importing reusable variables and remote functions (#354)
* catch errors in importing reusable variables and remote functions * updates
This commit is contained in:
committed by
Philipp Moritz
parent
a6452aca47
commit
a1e4268d37
@@ -11,8 +11,8 @@ if hasattr(ctypes, "windll"):
|
||||
|
||||
import config
|
||||
import serialization
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, register_module, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
|
||||
from worker import scheduler_info, visualize_computation_graph, task_info, init, connect, disconnect, get, put, remote, kill_workers, restart_workers_local
|
||||
from worker import Reusable, reusables
|
||||
from worker import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
|
||||
from libraylib import SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE
|
||||
from libraylib import ObjectID
|
||||
import internal
|
||||
|
||||
+112
-111
@@ -20,14 +20,6 @@ import services
|
||||
import libnumbuf
|
||||
import libraylib as raylib
|
||||
|
||||
# These three constants are used to define the mode that a worker is running in.
|
||||
# Right now, this is only used for determining how to print information about
|
||||
# task failures.
|
||||
SCRIPT_MODE = 0
|
||||
WORKER_MODE = 1
|
||||
PYTHON_MODE = 2
|
||||
SILENT_MODE = 3 # This is only used during testing.
|
||||
|
||||
class RayTaskError(Exception):
|
||||
"""An object used internally to represent a task that threw an exception.
|
||||
|
||||
@@ -341,7 +333,7 @@ class RayReusables(object):
|
||||
raise Exception("To set a reusable variable, you must pass in a Reusable object")
|
||||
self._names.add(name)
|
||||
self._reusables[name] = reusable
|
||||
if _mode() in [SCRIPT_MODE, SILENT_MODE]:
|
||||
if _mode() in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
_export_reusable_variable(name, reusable)
|
||||
elif _mode() is None:
|
||||
self._cached_reusables.append((name, reusable))
|
||||
@@ -369,11 +361,13 @@ class Worker(object):
|
||||
handle (worker capsule): A Python object wrapping a C++ Worker object.
|
||||
mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE, SILENT_MODE,
|
||||
and WORKER_MODE.
|
||||
cached_remote_functions (List[str]): A list of serialized remote functions
|
||||
that were defined before the worker called connect. When the worker
|
||||
cached_remote_functions (List[Tuple[str, str]]): A list of pairs
|
||||
representing the remote functions that were defined before the worker
|
||||
called connect. The first element is the name of the remote function, and
|
||||
the second element is the serialized remote function. When the worker
|
||||
eventually does call connect, if it is a driver, it will export these
|
||||
functions to the scheduler. If cached_remote_functions is None, that means
|
||||
that connect has been called.
|
||||
that connect has been called already.
|
||||
num_failed_tasks (int): The number of tasks that have failed and whose error
|
||||
messages have been displayed to the user. We use this value to know when
|
||||
a failed task hasn't been seen by the user and should be displayed.
|
||||
@@ -507,20 +501,6 @@ class Worker(object):
|
||||
"""Make two object IDs refer to the same object."""
|
||||
raylib.alias_objectids(self.handle, alias_objectid, target_objectid)
|
||||
|
||||
def register_function(self, function):
|
||||
"""Register a function with the scheduler.
|
||||
|
||||
Notify the scheduler that this worker can execute the function with name
|
||||
func_name. After this call, the scheduler can send tasks for executing
|
||||
the function to this worker.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function that this worker can execute.
|
||||
"""
|
||||
_logger().info("Registering function {}.".format(function.func_name))
|
||||
raylib.register_function(self.handle, function.func_name, len(function.return_types))
|
||||
self.functions[function.func_name] = function
|
||||
|
||||
def submit_task(self, func_name, args):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
@@ -536,23 +516,8 @@ class Worker(object):
|
||||
"""
|
||||
task_capsule = serialization.serialize_task(self.handle, func_name, args)
|
||||
objectids = raylib.submit_task(self.handle, task_capsule)
|
||||
if self.mode == SCRIPT_MODE:
|
||||
self.print_new_failures()
|
||||
return objectids
|
||||
|
||||
def print_new_failures(self):
|
||||
"""Print information about tasks."""
|
||||
task_data = raylib.task_info(self.handle)
|
||||
num_tasks_succeeded = task_data["num_succeeded"]
|
||||
num_tasks_in_progress = len(task_data["running_tasks"])
|
||||
num_new_tasks_failed = len(task_data["failed_tasks"]) - self.num_failed_tasks
|
||||
if num_new_tasks_failed > 0:
|
||||
# Print the new tasks that have failed.
|
||||
for task_status in task_data["failed_tasks"][self.num_failed_tasks:]:
|
||||
print_failed_task(task_status)
|
||||
print "{}Error: {} new task{} failed.{}".format(colorama.Fore.RED, num_new_tasks_failed, "s" if num_new_tasks_failed > 1 else "", colorama.Fore.RESET)
|
||||
self.num_failed_tasks = len(task_data["failed_tasks"])
|
||||
|
||||
global_worker = Worker()
|
||||
"""Worker: The global Worker object for this worker process.
|
||||
|
||||
@@ -643,24 +608,7 @@ def task_info(worker=global_worker):
|
||||
check_connected(worker)
|
||||
return raylib.task_info(worker.handle)
|
||||
|
||||
def register_module(module, worker=global_worker):
|
||||
"""Register each remote function in a module with the scheduler.
|
||||
|
||||
This registers each remote function in the module with the scheduler, so tasks
|
||||
with those functions can be scheduled on this worker.
|
||||
|
||||
args:
|
||||
module (module): The module of functions to register.
|
||||
"""
|
||||
check_connected(worker)
|
||||
_logger().info("registering functions in module {}.".format(module.__name__))
|
||||
for name in dir(module):
|
||||
val = getattr(module, name)
|
||||
if hasattr(val, "is_remote") and val.is_remote:
|
||||
_logger().info("registering {}.".format(val.func_name))
|
||||
worker.register_function(val)
|
||||
|
||||
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=SCRIPT_MODE):
|
||||
def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_address=None, node_ip_address=None, driver_mode=raylib.SCRIPT_MODE):
|
||||
"""Either connect to an existing Ray cluster or start one and connect to it.
|
||||
|
||||
This method handles two cases. Either a Ray cluster already exists and we
|
||||
@@ -692,8 +640,8 @@ def init(start_ray_local=False, num_workers=None, num_objstores=None, scheduler_
|
||||
# and we connect to them.
|
||||
if (scheduler_address is not None) or (node_ip_address is not None):
|
||||
raise Exception("If start_ray_local=True, then you cannot pass in a scheduler_address or a node_ip_address.")
|
||||
if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]:
|
||||
raise Exception("If start_ray_local=True, then driver_mode must be in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE].")
|
||||
if driver_mode not in [raylib.SCRIPT_MODE, raylib.PYTHON_MODE, raylib.SILENT_MODE]:
|
||||
raise Exception("If start_ray_local=True, then driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")
|
||||
# Use the address 127.0.0.1 in local mode.
|
||||
node_ip_address = "127.0.0.1"
|
||||
num_workers = 1 if num_workers is None else num_workers
|
||||
@@ -727,7 +675,7 @@ def cleanup(worker=global_worker):
|
||||
|
||||
atexit.register(cleanup)
|
||||
|
||||
def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=WORKER_MODE):
|
||||
def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver=False, worker=global_worker, mode=raylib.WORKER_MODE):
|
||||
"""Connect this worker to the scheduler and an object store.
|
||||
|
||||
Args:
|
||||
@@ -757,13 +705,19 @@ def connect(node_ip_address, scheduler_address, objstore_address=None, is_driver
|
||||
_logger().propagate = False
|
||||
# Configure the logging from the worker C++ code.
|
||||
raylib.set_log_config(config.get_log_file_path("-".join(["worker", worker.worker_address, "c++"]) + ".log"))
|
||||
if mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
for function_to_export in worker.cached_remote_functions:
|
||||
raylib.export_function(worker.handle, function_to_export)
|
||||
if mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
for function_name, function_to_export in worker.cached_remote_functions:
|
||||
raylib.export_remote_function(worker.handle, function_name, function_to_export)
|
||||
for name, reusable_variable in reusables._cached_reusables:
|
||||
_export_reusable_variable(name, reusable_variable)
|
||||
worker.cached_remote_functions = None
|
||||
reusables._cached_reusables = None
|
||||
# Start the driver's WorkerService (if this is a driver). This will receive
|
||||
# GRPC commands from the scheduler to print error messages. We pass in the
|
||||
# mode below. This tells the WorkerService whether it is operating for a
|
||||
# driver or a worker and whether it should suppress errors or not.
|
||||
if is_driver:
|
||||
raylib.start_worker_service(worker.handle, mode)
|
||||
|
||||
def disconnect(worker=global_worker):
|
||||
"""Disconnect this worker from the scheduler and object store."""
|
||||
@@ -791,11 +745,9 @@ def get(objectid, worker=global_worker):
|
||||
A Python object
|
||||
"""
|
||||
check_connected(worker)
|
||||
if worker.mode == PYTHON_MODE:
|
||||
return objectid # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
|
||||
if worker.mode == raylib.PYTHON_MODE:
|
||||
return objectid # In raylib.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
|
||||
raylib.request_object(worker.handle, objectid)
|
||||
if worker.mode == SCRIPT_MODE:
|
||||
worker.print_new_failures()
|
||||
value = worker.get_object(objectid)
|
||||
if isinstance(value, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created this object
|
||||
@@ -813,12 +765,10 @@ def put(value, worker=global_worker):
|
||||
The object ID assigned to this value.
|
||||
"""
|
||||
check_connected(worker)
|
||||
if worker.mode == PYTHON_MODE:
|
||||
return value # In PYTHON_MODE, ray.put is the identity operation
|
||||
if worker.mode == raylib.PYTHON_MODE:
|
||||
return value # In raylib.PYTHON_MODE, ray.put is the identity operation
|
||||
objectid = raylib.get_objectid(worker.handle)
|
||||
worker.put_object(objectid, value)
|
||||
if worker.mode == SCRIPT_MODE:
|
||||
worker.print_new_failures()
|
||||
return objectid
|
||||
|
||||
def kill_workers(worker=global_worker):
|
||||
@@ -878,27 +828,33 @@ def format_error_message(exception_message):
|
||||
def main_loop(worker=global_worker):
|
||||
"""The main loop a worker runs to receive and execute tasks.
|
||||
|
||||
This method is an infinite loop. It waits to receive tasks from the scheduler.
|
||||
When it receives a task, it first deserializes the task. Then it retrieves the
|
||||
values for any arguments that were passed in as object IDs. Then it
|
||||
passes the arguments to the actual function. Then it stores the outputs of the
|
||||
function in the local object store. Then it notifies the scheduler that it
|
||||
completed the task.
|
||||
|
||||
If the process of getting the arguments for execution (which does some type
|
||||
checking) or the process of executing the task fail, then the main loop will
|
||||
catch the exception and store RayTaskError objects containing the relevant
|
||||
error messages in the object store in place of the actual outputs. These
|
||||
objects are used to propagate the error messages.
|
||||
This method is an infinite loop. It waits to receive commands from the
|
||||
scheduler. A command may consist of a task to execute, a remote function to
|
||||
import, a reusable variable to import, or an order to terminate the worker
|
||||
process. The worker executes the command, notifies the scheduler of any errors
|
||||
that occurred while executing the command, and waits for the next command.
|
||||
"""
|
||||
if not raylib.connected(worker.handle):
|
||||
raise Exception("Worker is attempting to enter main_loop but has not been connected yet.")
|
||||
raylib.start_worker_service(worker.handle)
|
||||
# We pass in raylib.WORKER_MODE below to indicate that the WorkerService is
|
||||
# operating for a worker and not a driver.
|
||||
raylib.start_worker_service(worker.handle, raylib.WORKER_MODE)
|
||||
|
||||
def process_task(task): # wrapping these lines in a function should cause the local variables to go out of scope more quickly, which is useful for inspecting reference counts
|
||||
func_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
|
||||
"""Execute a task assigned to this worker.
|
||||
|
||||
This method deserializes a task from the scheduler, and attempts to execute
|
||||
the task. If the task succeeds, the outputs are stored in the local object
|
||||
store. If the task throws an exception, RayTaskError objects are stored in
|
||||
the object store to represent the failed task (these will be retrieved by
|
||||
calls to get or by subsequent tasks that use the outputs of this task).
|
||||
After the task executes, the worker resets any reusable variables that were
|
||||
accessed by the task.
|
||||
"""
|
||||
function_name, args, return_objectids = serialization.deserialize_task(worker.handle, task)
|
||||
try:
|
||||
arguments = get_arguments_for_execution(worker.functions[func_name], args, worker) # get args from objstore
|
||||
outputs = worker.functions[func_name].executor(arguments) # execute the function
|
||||
arguments = get_arguments_for_execution(worker.functions[function_name], args, worker) # get args from objstore
|
||||
outputs = worker.functions[function_name].executor(arguments) # execute the function
|
||||
if len(return_objectids) == 1:
|
||||
outputs = (outputs,)
|
||||
except Exception as e:
|
||||
@@ -906,37 +862,82 @@ def main_loop(worker=global_worker):
|
||||
# whether the exception was thrown in the task execution by whether the
|
||||
# variable "arguments" is defined.
|
||||
traceback_str = format_error_message(traceback.format_exc()) if "arguments" in locals() else None
|
||||
failure_object = RayTaskError(func_name, e, traceback_str)
|
||||
failure_object = RayTaskError(function_name, e, traceback_str)
|
||||
failure_objects = [failure_object for _ in range(len(return_objectids))]
|
||||
store_outputs_in_objstore(return_objectids, failure_objects, worker)
|
||||
raylib.notify_task_completed(worker.handle, False, str(failure_object))
|
||||
_logger().info("Worker threw exception with message: \n\n{}\n, while running function {}.".format(str(failure_object), func_name))
|
||||
# Notify the scheduler that the task failed.
|
||||
raylib.notify_failure(worker.handle, function_name, str(failure_object), raylib.FailedTask)
|
||||
_logger().info("While running function {}, worker threw exception with message: \n\n{}\n".format(function_name, str(failure_object)))
|
||||
else:
|
||||
store_outputs_in_objstore(return_objectids, outputs, worker) # store output in local object store
|
||||
raylib.notify_task_completed(worker.handle, True, "") # notify the scheduler that the task completed successfully
|
||||
finally:
|
||||
# Notify the scheduler that the task is done. This happens regardless of
|
||||
# whether the task succeeded or failed.
|
||||
raylib.notify_task_completed(worker.handle)
|
||||
try:
|
||||
# Reinitialize the values of reusable variables that were used in the task
|
||||
# above so that changes made to their state do not affect other tasks.
|
||||
reusables._reinitialize()
|
||||
except Exception as e:
|
||||
# The attempt to reinitialize the reusable variables threw an exception.
|
||||
# We record the traceback and notify the scheduler.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedReinitializeReusableVariable)
|
||||
_logger().info("While attempting to reinitialize the reusable variables after running function {}, the worker threw exception with message: \n\n{}\n".format(function_name, traceback_str))
|
||||
|
||||
def process_remote_function(function_name, serialized_function):
|
||||
"""Import a remote function."""
|
||||
try:
|
||||
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
|
||||
except:
|
||||
# If an exception was thrown when the remote function was imported, we
|
||||
# record the traceback and notify the scheduler of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
_logger().info("Failed to import remote function {}. Failed with message: \n\n{}\n".format(function_name, traceback_str))
|
||||
# Notify the scheduler that the remote function failed to import.
|
||||
raylib.notify_failure(worker.handle, function_name, traceback_str, raylib.FailedRemoteFunctionImport)
|
||||
else:
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
|
||||
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
|
||||
_logger().info("Successfully imported remote function {}.".format(function_name))
|
||||
# Notify the scheduler that the remote function imported successfully.
|
||||
# We pass the function name and its number of return values.
|
||||
raylib.register_remote_function(worker.handle, function_name, len(return_types))
|
||||
|
||||
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
|
||||
"""Import a reusable variable."""
|
||||
try:
|
||||
initializer = pickling.loads(initializer_str)
|
||||
reinitializer = pickling.loads(reinitializer_str)
|
||||
reusables.__setattr__(reusable_variable_name, Reusable(initializer, reinitializer))
|
||||
except:
|
||||
# If an exception was thrown when the reusable variable was imported, we
|
||||
# record the traceback and notify the scheduler of the failure.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
_logger().info("Failed to import reusable variable {}. Failed with message: \n\n{}\n".format(reusable_variable_name, traceback_str))
|
||||
# Notify the scheduler that the reusable variable failed to import.
|
||||
raylib.notify_failure(worker.handle, reusable_variable_name, traceback_str, raylib.FailedReusableVariableImport)
|
||||
else:
|
||||
_logger().info("Successfully imported reusable variable {}.".format(reusable_variable_name))
|
||||
|
||||
while True:
|
||||
command, command_args = raylib.wait_for_next_message(worker.handle)
|
||||
try:
|
||||
if command == "die":
|
||||
# We use this as a mechanism to allow the scheduler to kill workers.
|
||||
_logger().info("Received a 'die' command, and will exit now.")
|
||||
break
|
||||
elif command == "function":
|
||||
(function, arg_types, return_types, module) = pickling.loads(command_args)
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
worker.register_function(remote(arg_types, return_types, worker)(function))
|
||||
elif command == "reusable_variable":
|
||||
name, initializer_str, reinitializer_str = command_args
|
||||
initializer = pickling.loads(initializer_str)
|
||||
reinitializer = pickling.loads(reinitializer_str)
|
||||
reusables.__setattr__(name, Reusable(initializer, reinitializer))
|
||||
elif command == "task":
|
||||
process_task(command_args)
|
||||
elif command == "function":
|
||||
function_name, serialized_function = command_args
|
||||
process_remote_function(function_name, serialized_function)
|
||||
elif command == "reusable_variable":
|
||||
name, initializer_str, reinitializer_str = command_args
|
||||
process_reusable_variable(name, initializer_str, reinitializer_str)
|
||||
else:
|
||||
_logger().info("Reached the end of the if-else loop in the main loop. This should be unreachable.")
|
||||
assert False, "This code should be unreachable."
|
||||
finally:
|
||||
# Allow releasing the variables BEFORE we wait for the next message or exit the block
|
||||
@@ -979,7 +980,7 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
reusable (Reusable): The reusable object containing code for initializing
|
||||
and reinitializing the variable.
|
||||
"""
|
||||
if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
|
||||
if _mode(worker) not in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
@@ -996,8 +997,8 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if _mode() == PYTHON_MODE:
|
||||
# In PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
if _mode() == raylib.PYTHON_MODE:
|
||||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
@@ -1035,7 +1036,7 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]:
|
||||
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
@@ -1046,10 +1047,10 @@ def remote(arg_types, return_types, worker=global_worker):
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
|
||||
raylib.export_function(worker.handle, to_export)
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.export_remote_function(worker.handle, func_name, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append(to_export)
|
||||
worker.cached_remote_functions.append((func_name, to_export))
|
||||
return func_invoker
|
||||
return remote_decorator
|
||||
|
||||
|
||||
+31
-32
@@ -22,8 +22,8 @@ service Scheduler {
|
||||
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerReply);
|
||||
// Register an object store with the scheduler
|
||||
rpc RegisterObjStore(RegisterObjStoreRequest) returns (RegisterObjStoreReply);
|
||||
// Tell the scheduler that a worker can execute a certain function
|
||||
rpc RegisterFunction(RegisterFunctionRequest) returns (AckReply);
|
||||
// Tell the scheduler that a worker successfully imported a remote function.
|
||||
rpc RegisterRemoteFunction(RegisterRemoteFunctionRequest) returns (AckReply);
|
||||
// Asks the scheduler to execute a task, immediately returns an object ID to the result
|
||||
rpc SubmitTask(SubmitTaskRequest) returns (SubmitTaskReply);
|
||||
// Increment the count of the object ID
|
||||
@@ -53,9 +53,11 @@ service Scheduler {
|
||||
// Kills the workers
|
||||
rpc KillWorkers(KillWorkersRequest) returns (KillWorkersReply);
|
||||
// Exports function to the workers
|
||||
rpc ExportFunction(ExportFunctionRequest) returns (ExportFunctionReply);
|
||||
rpc ExportRemoteFunction(ExportRemoteFunctionRequest) returns (AckReply);
|
||||
// Ship an initializer and reinitializer for a reusable variable to the workers
|
||||
rpc ExportReusableVariable(ExportReusableVariableRequest) returns (AckReply);
|
||||
// Notify the scheduler that a failure occurred while running a task, importing a remote function, or importing a reusable variable.
|
||||
rpc NotifyFailure(NotifyFailureRequest) returns (AckReply);
|
||||
}
|
||||
|
||||
message AckReply {
|
||||
@@ -82,10 +84,14 @@ message RegisterObjStoreReply {
|
||||
uint64 objstoreid = 1; // Object store ID assigned by the scheduler
|
||||
}
|
||||
|
||||
message RegisterFunctionRequest {
|
||||
message RegisterRemoteFunctionRequest {
|
||||
uint64 workerid = 1; // Worker that can execute the function
|
||||
string fnname = 2; // Name of the function that is registered
|
||||
uint64 num_return_vals = 3; // Number of return values of the function
|
||||
string function_name = 2; // Name of the remote function
|
||||
uint64 num_return_vals = 3; // Number of return values of the function. This is only present if the function was successfully imported.
|
||||
}
|
||||
|
||||
message NotifyFailure {
|
||||
Failure failure = 1; // The failure object.
|
||||
}
|
||||
|
||||
message SubmitTaskRequest {
|
||||
@@ -136,11 +142,6 @@ message DecrementRefCountRequest {
|
||||
|
||||
message ReadyForNewTaskRequest {
|
||||
uint64 workerid = 1; // ID of the worker which executed the task
|
||||
message PreviousTaskInfo {
|
||||
bool task_succeeded = 1; // True if the task succeeded, false if it threw an exception
|
||||
string error_message = 2; // The contents of the exception, if the task threw an exception
|
||||
}
|
||||
PreviousTaskInfo previous_task_info = 2; // Information about the previous task, this is only present if there was a previous task
|
||||
}
|
||||
|
||||
message ChangeCountRequest {
|
||||
@@ -221,10 +222,11 @@ message TaskInfoRequest {
|
||||
}
|
||||
|
||||
message TaskInfoReply {
|
||||
repeated TaskStatus failed_task = 1;
|
||||
repeated TaskStatus running_task = 2;
|
||||
uint64 num_succeeded = 3;
|
||||
// TODO(mehrdadn): We'll want to return information from computation_graph since it's important for visualizing tasks that have been completed etc.
|
||||
repeated TaskStatus failed_task = 1; // The tasks that have failed.
|
||||
repeated TaskStatus running_task = 2; // The tasks that are currently running.
|
||||
repeated Failure failed_remote_function_import = 3; // The remote function imports that failed.
|
||||
repeated Failure failed_reusable_variable_import = 4; // The reusable variable imports that failed.
|
||||
repeated Failure failed_reinitialize_reusable_variable = 5; // The reusable variable reinitializations that failed.
|
||||
}
|
||||
|
||||
message KillWorkersRequest {
|
||||
@@ -234,17 +236,18 @@ message KillWorkersReply {
|
||||
bool success = 1; // Currently, the only reason to fail is if there are workers still executing tasks
|
||||
}
|
||||
|
||||
message ExportFunctionRequest {
|
||||
message ExportRemoteFunctionRequest {
|
||||
Function function = 1;
|
||||
}
|
||||
|
||||
message ExportFunctionReply {
|
||||
}
|
||||
|
||||
message ExportReusableVariableRequest {
|
||||
ReusableVar reusable_variable = 1; // The reusable variable to export.
|
||||
}
|
||||
|
||||
message NotifyFailureRequest {
|
||||
Failure failure = 1; // The failure object.
|
||||
}
|
||||
|
||||
// These messages are for getting information about the object store state
|
||||
|
||||
message ObjStoreInfoRequest {
|
||||
@@ -259,26 +262,21 @@ message ObjStoreInfoReply {
|
||||
// Workers
|
||||
|
||||
service WorkerService {
|
||||
rpc ExecuteTask(ExecuteTaskRequest) returns (ExecuteTaskReply); // Scheduler calls a function from the worker
|
||||
rpc ImportFunction(ImportFunctionRequest) returns (ImportFunctionReply); // Scheduler imports a function into the worker
|
||||
rpc ExecuteTask(ExecuteTaskRequest) returns (AckReply); // Scheduler calls a function from the worker
|
||||
rpc ImportRemoteFunction(ImportRemoteFunctionRequest) returns (AckReply); // Scheduler imports a function into the worker
|
||||
rpc ImportReusableVariable(ImportReusableVariableRequest) returns (AckReply); // Scheduler imports a reusable variable into the worker
|
||||
rpc Die(DieRequest) returns (DieReply); // Kills this worker
|
||||
rpc Die(DieRequest) returns (AckReply); // Kills this worker
|
||||
rpc PrintErrorMessage(PrintErrorMessageRequest) returns (AckReply); // Causes an error message to be printed.
|
||||
}
|
||||
|
||||
message ExecuteTaskRequest {
|
||||
Task task = 1; // Contains name of the function to be executed and arguments
|
||||
}
|
||||
|
||||
message ExecuteTaskReply {
|
||||
}
|
||||
|
||||
message ImportFunctionRequest {
|
||||
message ImportRemoteFunctionRequest {
|
||||
Function function = 1;
|
||||
}
|
||||
|
||||
message ImportFunctionReply {
|
||||
}
|
||||
|
||||
message ImportReusableVariableRequest {
|
||||
ReusableVar reusable_variable = 1; // The reusable variable to export.
|
||||
}
|
||||
@@ -286,9 +284,6 @@ message ImportReusableVariableRequest {
|
||||
message DieRequest {
|
||||
}
|
||||
|
||||
message DieReply {
|
||||
}
|
||||
|
||||
// This message is used by the worker service to send messages to the worker
|
||||
// that are processed by the worker's main loop.
|
||||
message WorkerMessage {
|
||||
@@ -298,3 +293,7 @@ message WorkerMessage {
|
||||
ReusableVar reusable_variable = 3; // A reusable variable to import on the worker.
|
||||
}
|
||||
}
|
||||
|
||||
message PrintErrorMessageRequest {
|
||||
Failure failure = 1; // The failure object.
|
||||
}
|
||||
|
||||
+20
-1
@@ -38,7 +38,8 @@ message PyObj {
|
||||
|
||||
// Used for shipping remote functions to workers
|
||||
message Function {
|
||||
bytes implementation = 1;
|
||||
string name = 1;
|
||||
bytes implementation = 2;
|
||||
}
|
||||
|
||||
message ReusableVar {
|
||||
@@ -47,6 +48,24 @@ message ReusableVar {
|
||||
Function reinitializer = 3; // A serialized version of the function that reinitializes the reusable variable.
|
||||
}
|
||||
|
||||
enum FailedType {
|
||||
FailedTask = 0;
|
||||
FailedRemoteFunctionImport = 1;
|
||||
FailedReusableVariableImport = 2;
|
||||
FailedReinitializeReusableVariable = 3;
|
||||
}
|
||||
|
||||
// Used to represent exceptions thrown in Python. This will happen when a task
|
||||
// fails to execute, a remote function fails to be imported, or a reusable
|
||||
// variable fails to be imported.
|
||||
message Failure {
|
||||
FailedType type = 1; // The type of the failure.
|
||||
uint64 workerid = 2; // The id of the worker on which the failure occurred.
|
||||
string worker_address = 3; // The address of the worker on which the failure occurred. This contains the same information as the workerid.
|
||||
string name = 4; // The name of the failed object.
|
||||
string error_message = 5; // The error message from the failure.
|
||||
}
|
||||
|
||||
// Union of possible object types
|
||||
message Obj {
|
||||
String string_data = 1;
|
||||
|
||||
+72
-18
@@ -715,7 +715,10 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
|
||||
PyTuple_SetItem(t, 1, deserialize_task(worker_capsule, message->task()));
|
||||
} else if (function_present) {
|
||||
PyTuple_SetItem(t, 0, PyString_FromString("function"));
|
||||
PyTuple_SetItem(t, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
|
||||
PyObject* remote_function_data = PyTuple_New(2);
|
||||
PyTuple_SetItem(remote_function_data, 0, PyString_FromStringAndSize(message->function().name().data(), static_cast<ssize_t>(message->function().name().size())));
|
||||
PyTuple_SetItem(remote_function_data, 1, PyString_FromStringAndSize(message->function().implementation().data(), static_cast<ssize_t>(message->function().implementation().size())));
|
||||
PyTuple_SetItem(t, 1, remote_function_data);
|
||||
} else if (reusable_variable_present) {
|
||||
PyTuple_SetItem(t, 0, PyString_FromString("reusable_variable"));
|
||||
PyObject* reusable_variable = PyTuple_New(3);
|
||||
@@ -734,14 +737,15 @@ static PyObject* wait_for_next_message(PyObject* self, PyObject* args) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* export_function(PyObject* self, PyObject* args) {
|
||||
static PyObject* export_remote_function(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
const char* function_name;
|
||||
const char* function;
|
||||
int function_size;
|
||||
if (!PyArg_ParseTuple(args, "O&s#", &PyObjectToWorker, &worker, &function, &function_size)) {
|
||||
if (!PyArg_ParseTuple(args, "O&ss#", &PyObjectToWorker, &worker, &function_name, &function, &function_size)) {
|
||||
return NULL;
|
||||
}
|
||||
if (worker->export_function(std::string(function, static_cast<size_t>(function_size)))) {
|
||||
if (worker->export_remote_function(std::string(function_name), std::string(function, static_cast<size_t>(function_size)))) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@@ -795,25 +799,33 @@ static PyObject* submit_task(PyObject* self, PyObject* args) {
|
||||
|
||||
static PyObject* notify_task_completed(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
PyObject* task_succeeded_obj;
|
||||
const char* error_message_ptr;
|
||||
if (!PyArg_ParseTuple(args, "O&Os", &PyObjectToWorker, &worker, &task_succeeded_obj, &error_message_ptr)) {
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
|
||||
return NULL;
|
||||
}
|
||||
std::string error_message(error_message_ptr);
|
||||
bool task_succeeded = PyObject_IsTrue(task_succeeded_obj);
|
||||
worker->notify_task_completed(task_succeeded, error_message);
|
||||
worker->notify_task_completed();
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
static PyObject* register_function(PyObject* self, PyObject* args) {
|
||||
static PyObject* register_remote_function(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
const char* function_name;
|
||||
int num_return_vals;
|
||||
if (!PyArg_ParseTuple(args, "O&si", &PyObjectToWorker, &worker, &function_name, &num_return_vals)) {
|
||||
return NULL;
|
||||
}
|
||||
worker->register_function(std::string(function_name), num_return_vals);
|
||||
worker->register_remote_function(std::string(function_name), num_return_vals);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
// Python binding: notify_failure(worker, name, error_message, type).
//
// Forwards a failure report to Worker::notify_failure. `name` and
// `error_message` are NUL-terminated strings and `type` is an integer that is
// converted to a FailedType (the FailedType values are exported as integer
// constants by this module's init function).
//
// Returns None on success, or NULL with a Python exception set when the
// arguments cannot be parsed.
static PyObject* notify_failure(PyObject* self, PyObject* args) {
  Worker* worker;
  const char* name;
  const char* error_message;
  // The "i" format unit writes through an int*, so parse into a plain int and
  // convert afterwards rather than handing PyArg_ParseTuple a FailedType*.
  int type;
  if (!PyArg_ParseTuple(args, "O&ssi", &PyObjectToWorker, &worker, &name, &error_message, &type)) {
    return NULL;
  }
  worker->notify_failure(static_cast<FailedType>(type), std::string(name), std::string(error_message));
  Py_RETURN_NONE;
}
|
||||
|
||||
@@ -887,10 +899,11 @@ static PyObject* alias_objectids(PyObject* self, PyObject* args) {
|
||||
|
||||
static PyObject* start_worker_service(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
|
||||
Mode mode;
|
||||
if (!PyArg_ParseTuple(args, "O&i", &PyObjectToWorker, &worker, &mode)) {
|
||||
return NULL;
|
||||
}
|
||||
worker->start_worker_service();
|
||||
worker->start_worker_service(mode);
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
|
||||
@@ -919,6 +932,15 @@ static PyObject* scheduler_info(PyObject* self, PyObject* args) {
|
||||
return dict;
|
||||
}
|
||||
|
||||
// Build a new Python dict describing a Failure message. The dict has the keys
// "workerid", "worker_address", "function_name" (filled from failure.name())
// and "error_message". Ownership of every key and value object is handed to
// the dict by set_dict_item_and_transfer_ownership.
static PyObject* failure_to_dict(const Failure& failure) {
  const std::string& address = failure.worker_address();
  const std::string& name = failure.name();
  const std::string& message = failure.error_message();

  PyObject* result = PyDict_New();
  set_dict_item_and_transfer_ownership(
      result,
      PyString_FromString("workerid"),
      PyInt_FromLong(failure.workerid()));
  set_dict_item_and_transfer_ownership(
      result,
      PyString_FromString("worker_address"),
      PyString_FromStringAndSize(address.data(), address.size()));
  set_dict_item_and_transfer_ownership(
      result,
      PyString_FromString("function_name"),
      PyString_FromStringAndSize(name.data(), name.size()));
  set_dict_item_and_transfer_ownership(
      result,
      PyString_FromString("error_message"),
      PyString_FromStringAndSize(message.data(), message.size()));
  return result;
}
|
||||
|
||||
static PyObject* task_info(PyObject* self, PyObject* args) {
|
||||
Worker* worker;
|
||||
if (!PyArg_ParseTuple(args, "O&", &PyObjectToWorker, &worker)) {
|
||||
@@ -950,10 +972,27 @@ static PyObject* task_info(PyObject* self, PyObject* args) {
|
||||
PyList_SetItem(running_tasks_list, i, info_dict);
|
||||
}
|
||||
|
||||
PyObject* failed_remote_function_imports = PyList_New(reply.failed_remote_function_import_size());
|
||||
for (size_t i = 0; i < reply.failed_remote_function_import_size(); ++i) {
|
||||
PyList_SetItem(failed_remote_function_imports, i, failure_to_dict(reply.failed_remote_function_import(i)));
|
||||
}
|
||||
|
||||
PyObject* failed_reusable_variable_imports = PyList_New(reply.failed_reusable_variable_import_size());
|
||||
for (size_t i = 0; i < reply.failed_reusable_variable_import_size(); ++i) {
|
||||
PyList_SetItem(failed_reusable_variable_imports, i, failure_to_dict(reply.failed_reusable_variable_import(i)));
|
||||
}
|
||||
|
||||
PyObject* failed_reinitialize_reusable_variables = PyList_New(reply.failed_reinitialize_reusable_variable_size());
|
||||
for (size_t i = 0; i < reply.failed_reinitialize_reusable_variable_size(); ++i) {
|
||||
PyList_SetItem(failed_reinitialize_reusable_variables, i, failure_to_dict(reply.failed_reinitialize_reusable_variable(i)));
|
||||
}
|
||||
|
||||
PyObject* dict = PyDict_New();
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_tasks"), failed_tasks_list);
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("running_tasks"), running_tasks_list);
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("num_succeeded"), PyInt_FromLong(reply.num_succeeded()));
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_remote_function_imports"), failed_remote_function_imports);
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reusable_variable_imports"), failed_reusable_variable_imports);
|
||||
set_dict_item_and_transfer_ownership(dict, PyString_FromString("failed_reinitialize_reusable_variables"), failed_reinitialize_reusable_variables);
|
||||
return dict;
|
||||
}
|
||||
|
||||
@@ -1008,7 +1047,8 @@ static PyMethodDef RayLibMethods[] = {
|
||||
{ "create_worker", create_worker, METH_VARARGS, "connect to the scheduler and the object store" },
|
||||
{ "disconnect", disconnect, METH_VARARGS, "disconnect the worker from the scheduler and the object store" },
|
||||
{ "connected", connected, METH_VARARGS, "check if the worker is connected to the scheduler and the object store" },
|
||||
{ "register_function", register_function, METH_VARARGS, "register a function with the scheduler" },
|
||||
{ "register_remote_function", register_remote_function, METH_VARARGS, "register a function with the scheduler" },
|
||||
{ "notify_failure", notify_failure, METH_VARARGS, "notify the scheduler of a failure" },
|
||||
{ "put_object", put_object, METH_VARARGS, "put a protocol buffer object (given as a capsule) on the local object store" },
|
||||
{ "get_object", get_object, METH_VARARGS, "get protocol buffer object from the local object store" },
|
||||
{ "get_objectid", get_objectid, METH_VARARGS, "register a new object reference with the scheduler" },
|
||||
@@ -1019,8 +1059,8 @@ static PyMethodDef RayLibMethods[] = {
|
||||
{ "notify_task_completed", notify_task_completed, METH_VARARGS, "notify the scheduler that a task has been completed" },
|
||||
{ "start_worker_service", start_worker_service, METH_VARARGS, "start the worker service" },
|
||||
{ "scheduler_info", scheduler_info, METH_VARARGS, "get info about scheduler state" },
|
||||
{ "task_info", task_info, METH_VARARGS, "get task statuses" },
|
||||
{ "export_function", export_function, METH_VARARGS, "export function to workers" },
|
||||
{ "task_info", task_info, METH_VARARGS, "get information about task statuses and failures" },
|
||||
{ "export_remote_function", export_remote_function, METH_VARARGS, "export a remote function to workers" },
|
||||
{ "export_reusable_variable", export_reusable_variable, METH_VARARGS, "export a reusable variable to the workers" },
|
||||
{ "dump_computation_graph", dump_computation_graph, METH_VARARGS, "dump the current computation graph to a file" },
|
||||
{ "set_log_config", set_log_config, METH_VARARGS, "set filename for raylib logging" },
|
||||
@@ -1046,6 +1086,20 @@ PyMODINIT_FUNC initlibraylib(void) {
|
||||
PyModule_AddObject(m, "ray_error", RayError);
|
||||
PyModule_AddObject(m, "ray_size_error", RaySizeError);
|
||||
import_array();
|
||||
|
||||
// Export constants used for the worker mode types so they can be accessed
|
||||
// from Python. The Mode enum is defined in worker.h.
|
||||
PyModule_AddIntConstant(m, "SCRIPT_MODE", Mode::SCRIPT_MODE);
|
||||
PyModule_AddIntConstant(m, "WORKER_MODE", Mode::WORKER_MODE);
|
||||
PyModule_AddIntConstant(m, "PYTHON_MODE", Mode::PYTHON_MODE);
|
||||
PyModule_AddIntConstant(m, "SILENT_MODE", Mode::SILENT_MODE);
|
||||
|
||||
// Export constants for the failure types so they can be accessed from Python.
|
||||
// The FailedType enum is defined in types.proto.
|
||||
PyModule_AddIntConstant(m, "FailedTask", FailedType::FailedTask);
|
||||
PyModule_AddIntConstant(m, "FailedRemoteFunctionImport", FailedType::FailedRemoteFunctionImport);
|
||||
PyModule_AddIntConstant(m, "FailedReusableVariableImport", FailedType::FailedReusableVariableImport);
|
||||
PyModule_AddIntConstant(m, "FailedReinitializeReusableVariable", FailedType::FailedReinitializeReusableVariable);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
+102
-45
@@ -256,7 +256,13 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo
|
||||
{
|
||||
auto workers = GET(workers_);
|
||||
workerid = workers->size();
|
||||
worker_address = node_ip_address + ":" + std::to_string(40000 + workerid);
|
||||
// Generate a random port number. This is currently a hack to avoid reusing
|
||||
// port numbers when we run the tests.
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
std::uniform_int_distribution<int> uni(0, 10000);
|
||||
int port_number = 40000 + uni(rng);
|
||||
worker_address = node_ip_address + ":" + std::to_string(port_number);
|
||||
workers->push_back(WorkerHandle());
|
||||
auto channel = grpc::CreateChannel(worker_address, grpc::InsecureChannelCredentials());
|
||||
(*workers)[workerid].channel = channel;
|
||||
@@ -279,13 +285,64 @@ Status SchedulerService::RegisterWorker(ServerContext* context, const RegisterWo
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) {
|
||||
RAY_LOG(RAY_INFO, "register function " << request->fnname() << " from workerid " << request->workerid());
|
||||
register_function(request->fnname(), request->workerid(), request->num_return_vals());
|
||||
Status SchedulerService::RegisterRemoteFunction(ServerContext* context, const RegisterRemoteFunctionRequest* request, AckReply* reply) {
|
||||
RAY_LOG(RAY_INFO, "register function " << request->function_name() << " from workerid " << request->workerid());
|
||||
register_function(request->function_name(), request->workerid(), request->num_return_vals());
|
||||
schedule();
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
// Handles a failure report sent by a worker.
//
// The failure is appended to the list that matches its type (failed tasks,
// failed remote function imports, failed reusable variable imports, or failed
// reusable variable reinitializations) and logged. It is then forwarded via
// PrintErrorMessage to every connected worker whose current task is
// ROOT_OPERATION, i.e. to the drivers.
//
// Always returns Status::OK; a driver that cannot be reached is only logged.
Status SchedulerService::NotifyFailure(ServerContext* context, const NotifyFailureRequest* request, AckReply* reply) {
  // Bind by reference: the request outlives this call, so no copy is needed.
  const Failure& failure = request->failure();
  WorkerId workerid = failure.workerid();
  if (failure.type() == FailedType::FailedTask) {
    // A task threw an exception while executing.
    TaskStatus failed_task_info;
    {
      auto workers = GET(workers_);
      failed_task_info.set_operationid((*workers)[workerid].current_task);
      failed_task_info.set_function_name(failure.name());
      failed_task_info.set_worker_address((*workers)[workerid].worker_address);
      failed_task_info.set_error_message(failure.error_message());
    }
    GET(failed_tasks_)->push_back(failed_task_info);
    RAY_LOG(RAY_INFO, "Error: Task " << failed_task_info.operationid() << " executing function " << failed_task_info.function_name() << " on worker " << workerid << " failed with error message:\n" << failed_task_info.error_message());
  } else if (failure.type() == FailedType::FailedRemoteFunctionImport) {
    // An exception was thrown while a remote function was being imported.
    GET(failed_remote_function_imports_)->push_back(failure);
    RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message());
  } else if (failure.type() == FailedType::FailedReusableVariableImport) {
    // An exception was thrown while a reusable variable was being imported.
    GET(failed_reusable_variable_imports_)->push_back(failure);
    RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message());
  } else if (failure.type() == FailedType::FailedReinitializeReusableVariable) {
    // An exception was thrown while a reusable variable was being
    // reinitialized after a remote function ran.
    GET(failed_reinitialize_reusable_variables_)->push_back(failure);
    RAY_LOG(RAY_INFO, "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message());
  } else {
    RAY_CHECK(false, "This code should be unreachable.");
  }
  // Print the failure on the relevant driver. TODO(rkn): At the moment, this
  // prints the failure on all of the drivers. It should probably only print it
  // on the driver that caused the problem.
  // NOTE(review): the handle returned by GET(workers_) stays alive across the
  // RPCs below; presumably that keeps workers_ locked for their duration.
  // Confirm that this is acceptable.
  auto workers = GET(workers_);
  for (size_t i = 0; i < workers->size(); ++i) {
    WorkerHandle* worker = &(*workers)[i];
    // Skip workers that are no longer connected and workers that are not
    // drivers.
    if (!worker->worker_stub || worker->current_task != ROOT_OPERATION) {
      continue;
    }
    ClientContext client_context;
    PrintErrorMessageRequest print_request;
    print_request.mutable_failure()->CopyFrom(failure);
    AckReply print_reply;
    Status status = worker->worker_stub->PrintErrorMessage(&client_context, print_request, &print_reply);
    if (!status.ok()) {
      // Forwarding is best effort, but do not lose the fact that it failed.
      RAY_LOG(RAY_INFO, "Failed to forward a failure to the driver with workerid " << i << ": " << status.error_message());
    }
  }
  return Status::OK;
}
|
||||
|
||||
Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) {
|
||||
ObjectID objectid = request->objectid();
|
||||
RAY_LOG(RAY_DEBUG, "object " << objectid << " ready on store " << request->objstoreid());
|
||||
@@ -306,43 +363,25 @@ Status SchedulerService::ObjReady(ServerContext* context, const ObjReadyRequest*
|
||||
|
||||
Status SchedulerService::ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) {
|
||||
WorkerId workerid = request->workerid();
|
||||
OperationId operationid = (*GET(workers_))[workerid].current_task;
|
||||
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
|
||||
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
|
||||
{
|
||||
// Check if the worker has been initialized yet, and if not, then give it
|
||||
// all of the exported functions and all of the exported reusable variables.
|
||||
auto workers = GET(workers_);
|
||||
if (!(*workers)[workerid].initialized) {
|
||||
// This should only happen once.
|
||||
// Import all remote functions on the worker.
|
||||
export_all_functions_to_worker(workerid, workers, GET(exported_functions_));
|
||||
// Import all reusable variables on the worker.
|
||||
export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_));
|
||||
// Mark the worker as initialized.
|
||||
(*workers)[workerid].initialized = true;
|
||||
}
|
||||
}
|
||||
if (request->has_previous_task_info()) {
|
||||
RAY_CHECK(operationid != NO_OPERATION, "request->has_previous_task_info() should not be true if operationid == NO_OPERATION.");
|
||||
std::string task_name;
|
||||
task_name = GET(computation_graph_)->get_task(operationid).name();
|
||||
TaskStatus info;
|
||||
OperationId operationid = (*workers)[workerid].current_task;
|
||||
RAY_LOG(RAY_INFO, "worker " << workerid << " is ready for a new task");
|
||||
RAY_CHECK(operationid != ROOT_OPERATION, "A driver appears to have called ReadyForNewTask.");
|
||||
{
|
||||
auto workers = GET(workers_);
|
||||
info.set_operationid(operationid);
|
||||
info.set_function_name(task_name);
|
||||
info.set_worker_address((*workers)[workerid].worker_address);
|
||||
info.set_error_message(request->previous_task_info().error_message());
|
||||
(*workers)[workerid].current_task = NO_OPERATION; // clear operation ID
|
||||
// Check if the worker has been initialized yet, and if not, then give it
|
||||
// all of the exported functions and all of the exported reusable variables.
|
||||
if (!(*workers)[workerid].initialized) {
|
||||
// This should only happen once.
|
||||
// Import all remote functions on the worker.
|
||||
export_all_functions_to_worker(workerid, workers, GET(exported_functions_));
|
||||
// Import all reusable variables on the worker.
|
||||
export_all_reusable_variables_to_worker(workerid, workers, GET(exported_reusable_variables_));
|
||||
// Mark the worker as initialized.
|
||||
(*workers)[workerid].initialized = true;
|
||||
}
|
||||
}
|
||||
if (!request->previous_task_info().task_succeeded()) {
|
||||
RAY_LOG(RAY_INFO, "Error: Task " << info.operationid() << " executing function " << info.function_name() << " on worker " << workerid << " failed with error message:\n" << info.error_message());
|
||||
GET(failed_tasks_)->push_back(info);
|
||||
} else {
|
||||
GET(successful_tasks_)->push_back(info.operationid());
|
||||
}
|
||||
// TODO(rkn): Handle task failure
|
||||
(*workers)[workerid].current_task = NO_OPERATION; // clear operation ID
|
||||
}
|
||||
GET(avail_workers_)->push_back(workerid);
|
||||
schedule();
|
||||
@@ -394,14 +433,18 @@ Status SchedulerService::SchedulerInfo(ServerContext* context, const SchedulerIn
|
||||
}
|
||||
|
||||
Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) {
|
||||
auto successful_tasks = GET(successful_tasks_);
|
||||
auto failed_tasks = GET(failed_tasks_);
|
||||
auto failed_remote_function_imports = GET(failed_remote_function_imports_);
|
||||
auto failed_reusable_variable_imports = GET(failed_reusable_variable_imports_);
|
||||
auto failed_reinitialize_reusable_variables = GET(failed_reinitialize_reusable_variables_);
|
||||
auto computation_graph = GET(computation_graph_);
|
||||
auto workers = GET(workers_);
|
||||
// Return information about the failed tasks.
|
||||
for (int i = 0; i < failed_tasks->size(); ++i) {
|
||||
TaskStatus* info = reply->add_failed_task();
|
||||
*info = (*failed_tasks)[i];
|
||||
}
|
||||
// Return information about currently running tasks.
|
||||
for (size_t i = 0; i < workers->size(); ++i) {
|
||||
OperationId operationid = (*workers)[i].current_task;
|
||||
if (operationid != NO_OPERATION && operationid != ROOT_OPERATION) {
|
||||
@@ -412,7 +455,21 @@ Status SchedulerService::TaskInfo(ServerContext* context, const TaskInfoRequest*
|
||||
info->set_worker_address((*workers)[i].worker_address);
|
||||
}
|
||||
}
|
||||
reply->set_num_succeeded(successful_tasks->size());
|
||||
// Return information about failed remote function imports.
|
||||
for (size_t i = 0; i < failed_remote_function_imports->size(); ++i) {
|
||||
Failure* failure = reply->add_failed_remote_function_import();
|
||||
*failure = (*failed_remote_function_imports)[i];
|
||||
}
|
||||
// Return information about failed reusable variable imports.
|
||||
for (size_t i = 0; i < failed_reusable_variable_imports->size(); ++i) {
|
||||
Failure* failure = reply->add_failed_reusable_variable_import();
|
||||
*failure = (*failed_reusable_variable_imports)[i];
|
||||
}
|
||||
// Return information about failed reusable variable reinitializations.
|
||||
for (size_t i = 0; i < failed_reinitialize_reusable_variables->size(); ++i) {
|
||||
Failure* failure = reply->add_failed_reinitialize_reusable_variable();
|
||||
*failure = (*failed_reinitialize_reusable_variables)[i];
|
||||
}
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
@@ -449,7 +506,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
|
||||
for (WorkerHandle* idle_worker : idle_workers) {
|
||||
ClientContext client_context;
|
||||
DieRequest die_request;
|
||||
DieReply die_reply;
|
||||
AckReply die_reply;
|
||||
// TODO: Fault handling... what if a worker refuses to die? We just assume it dies here.
|
||||
idle_worker->worker_stub->Die(&client_context, die_request, &die_reply);
|
||||
idle_worker->worker_stub.reset();
|
||||
@@ -464,7 +521,7 @@ Status SchedulerService::KillWorkers(ServerContext* context, const KillWorkersRe
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status SchedulerService::ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) {
|
||||
Status SchedulerService::ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) {
|
||||
auto workers = GET(workers_);
|
||||
auto exported_functions = GET(exported_functions_);
|
||||
// TODO(rkn): Does this do a deep copy?
|
||||
@@ -556,7 +613,7 @@ void SchedulerService::assign_task(OperationId operationid, WorkerId workerid, c
|
||||
const Task& task = computation_graph->get_task(operationid);
|
||||
ClientContext context;
|
||||
ExecuteTaskRequest request;
|
||||
ExecuteTaskReply reply;
|
||||
AckReply reply;
|
||||
RAY_LOG(RAY_INFO, "starting to send arguments");
|
||||
for (size_t i = 0; i < task.arg_size(); ++i) {
|
||||
if (!task.arg(i).has_obj()) {
|
||||
@@ -970,10 +1027,10 @@ void SchedulerService::get_equivalent_objectids(ObjectID objectid, std::vector<O
|
||||
void SchedulerService::export_function_to_worker(WorkerId workerid, int function_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<Function> > > &exported_functions) {
|
||||
RAY_LOG(RAY_INFO, "exporting function with index " << function_index << " to worker " << workerid);
|
||||
ClientContext import_context;
|
||||
ImportFunctionRequest import_request;
|
||||
ImportRemoteFunctionRequest import_request;
|
||||
import_request.mutable_function()->CopyFrom(*(*exported_functions)[function_index].get());
|
||||
ImportFunctionReply import_reply;
|
||||
(*workers)[workerid].worker_stub->ImportFunction(&import_context, import_request, &import_reply);
|
||||
AckReply import_reply;
|
||||
(*workers)[workerid].worker_stub->ImportRemoteFunction(&import_context, import_request, &import_reply);
|
||||
}
|
||||
|
||||
void SchedulerService::export_reusable_variable_to_worker(WorkerId workerid, int reusable_variable_index, MySynchronizedPtr<std::vector<WorkerHandle> > &workers, const MySynchronizedPtr<std::vector<std::unique_ptr<ReusableVar> > > &exported_reusable_variables) {
|
||||
|
||||
+9
-4
@@ -65,7 +65,7 @@ public:
|
||||
Status AliasObjectIDs(ServerContext* context, const AliasObjectIDsRequest* request, AckReply* reply) override;
|
||||
Status RegisterObjStore(ServerContext* context, const RegisterObjStoreRequest* request, RegisterObjStoreReply* reply) override;
|
||||
Status RegisterWorker(ServerContext* context, const RegisterWorkerRequest* request, RegisterWorkerReply* reply) override;
|
||||
Status RegisterFunction(ServerContext* context, const RegisterFunctionRequest* request, AckReply* reply) override;
|
||||
Status RegisterRemoteFunction(ServerContext* context, const RegisterRemoteFunctionRequest* request, AckReply* reply) override;
|
||||
Status ObjReady(ServerContext* context, const ObjReadyRequest* request, AckReply* reply) override;
|
||||
Status ReadyForNewTask(ServerContext* context, const ReadyForNewTaskRequest* request, AckReply* reply) override;
|
||||
Status IncrementRefCount(ServerContext* context, const IncrementRefCountRequest* request, AckReply* reply) override;
|
||||
@@ -74,8 +74,9 @@ public:
|
||||
Status SchedulerInfo(ServerContext* context, const SchedulerInfoRequest* request, SchedulerInfoReply* reply) override;
|
||||
Status TaskInfo(ServerContext* context, const TaskInfoRequest* request, TaskInfoReply* reply) override;
|
||||
Status KillWorkers(ServerContext* context, const KillWorkersRequest* request, KillWorkersReply* reply) override;
|
||||
Status ExportFunction(ServerContext* context, const ExportFunctionRequest* request, ExportFunctionReply* reply) override;
|
||||
Status ExportRemoteFunction(ServerContext* context, const ExportRemoteFunctionRequest* request, AckReply* reply) override;
|
||||
Status ExportReusableVariable(ServerContext* context, const ExportReusableVariableRequest* request, AckReply* reply) override;
|
||||
Status NotifyFailure(ServerContext*, const NotifyFailureRequest* request, AckReply* reply) override;
|
||||
|
||||
#ifdef NDEBUG
|
||||
// If we've disabled assertions, then just use regular SynchronizedPtr to skip lock checking.
|
||||
@@ -168,10 +169,14 @@ private:
|
||||
// When we unlock, we subtract back the field offset to restore it to the previous field that was locked.
|
||||
mutable Synchronized<std::vector<std::pair<unsigned long long, std::pair<size_t, const char*> > > > lock_orders_;
|
||||
|
||||
// List of the IDs of successful tasks
|
||||
Synchronized<std::vector<OperationId> > successful_tasks_; // Right now, we only use this information in the TaskInfo call.
|
||||
// List of failed tasks
|
||||
Synchronized<std::vector<TaskStatus> > failed_tasks_;
|
||||
// A list of remote function import failures.
|
||||
Synchronized<std::vector<Failure> > failed_remote_function_imports_;
|
||||
// A list of reusable variable import failures.
|
||||
Synchronized<std::vector<Failure> > failed_reusable_variable_imports_;
|
||||
// A list of reusable variable reinitialization failures.
|
||||
Synchronized<std::vector<Failure> > failed_reinitialize_reusable_variables_;
|
||||
// List of pending get calls.
|
||||
Synchronized<std::vector<std::pair<WorkerId, ObjectID> > > get_queue_;
|
||||
// The computation graph tracks the operations that have been submitted to the
|
||||
|
||||
+80
-33
@@ -9,12 +9,14 @@ extern "C" {
|
||||
static PyObject *RayError;
|
||||
}
|
||||
|
||||
inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address)
|
||||
: worker_address_(worker_address) {
|
||||
inline WorkerServiceImpl::WorkerServiceImpl(const std::string& worker_address, Mode mode)
|
||||
: worker_address_(worker_address),
|
||||
mode_(mode) {
|
||||
RAY_CHECK(send_queue_.connect(worker_address_, false), "error connecting send_queue_");
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) {
|
||||
Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) {
|
||||
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ExecuteTask can only be called on workers.");
|
||||
RAY_LOG(RAY_INFO, "invoked task " << request->task().name());
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->mutable_task()->CopyFrom(request->task());
|
||||
@@ -26,7 +28,8 @@ Status WorkerServiceImpl::ExecuteTask(ServerContext* context, const ExecuteTaskR
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) {
|
||||
Status WorkerServiceImpl::ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) {
|
||||
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportRemoteFunction can only be called on workers.");
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->mutable_function()->CopyFrom(request->function());
|
||||
RAY_LOG(RAY_INFO, "importing function");
|
||||
@@ -39,6 +42,7 @@ Status WorkerServiceImpl::ImportFunction(ServerContext* context, const ImportFun
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) {
|
||||
RAY_CHECK(mode_ == Mode::WORKER_MODE, "ImportReusableVariable can only be called on workers.");
|
||||
std::unique_ptr<WorkerMessage> message(new WorkerMessage());
|
||||
message->mutable_reusable_variable()->CopyFrom(request->reusable_variable());
|
||||
RAY_LOG(RAY_INFO, "importing reusable variable");
|
||||
@@ -50,18 +54,46 @@ Status WorkerServiceImpl::ImportReusableVariable(ServerContext* context, const I
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, DieReply* reply) {
|
||||
Status WorkerServiceImpl::Die(ServerContext* context, const DieRequest* request, AckReply* reply) {
|
||||
RAY_CHECK(mode_ == Mode::WORKER_MODE, "Die can only be called on workers.");
|
||||
WorkerMessage* message_ptr = NULL;
|
||||
RAY_CHECK(send_queue_.send(&message_ptr), "error sending over IPC");
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
// Prints a failure forwarded by the scheduler to standard output.
//
// Must not be called on a process running in WORKER_MODE. In SILENT_MODE
// (used by the tests) nothing is printed. Returns Status::OK.
Status WorkerServiceImpl::PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) {
  RAY_CHECK(mode_ != Mode::WORKER_MODE, "PrintErrorMessage can only be called on drivers.");
  if (mode_ == Mode::SILENT_MODE) {
    // Do not log error messages in this case. This is just used for the tests.
    return Status::OK;
  }
  const Failure& failure = request->failure();
  const WorkerId workerid = failure.workerid();
  switch (failure.type()) {
    case FailedType::FailedTask:
      // A task threw an exception while executing.
      std::cout << "Error: Worker " << workerid << " failed to execute function " << failure.name() << ". Failed with error message:\n" << failure.error_message() << std::endl;
      break;
    case FailedType::FailedRemoteFunctionImport:
      // An exception was thrown while a remote function was being imported.
      std::cout << "Error: Worker " << workerid << " failed to import remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
      break;
    case FailedType::FailedReusableVariableImport:
      // An exception was thrown while a reusable variable was being imported.
      std::cout << "Error: Worker " << workerid << " failed to import reusable variable " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
      break;
    case FailedType::FailedReinitializeReusableVariable:
      // An exception was thrown while a reusable variable was being reinitialized.
      std::cout << "Error: Worker " << workerid << " failed to reinitialize a reusable variable after running remote function " << failure.name() << ", failed with error message:\n" << failure.error_message() << std::endl;
      break;
    default:
      RAY_CHECK(false, "This code should be unreachable.");
      break;
  }
  return Status::OK;
}
|
||||
|
||||
// Constructs a worker that talks to the scheduler at `scheduler_address`.
// Opens an insecure gRPC channel to that address and creates the scheduler
// stub on top of it.
Worker::Worker(const std::string& scheduler_address)
    : scheduler_address_(scheduler_address) {
  scheduler_stub_ = Scheduler::NewStub(
      grpc::CreateChannel(scheduler_address, grpc::InsecureChannelCredentials()));
}
|
||||
|
||||
|
||||
SubmitTaskReply Worker::submit_task(SubmitTaskRequest* request, int max_retries, int retry_wait_milliseconds) {
|
||||
RAY_CHECK(connected_, "Attempted to perform submit_task but failed.");
|
||||
SubmitTaskReply reply;
|
||||
@@ -312,15 +344,28 @@ void Worker::decrement_reference_count(std::vector<ObjectID> &objectids) {
|
||||
}
|
||||
}
|
||||
|
||||
void Worker::register_function(const std::string& name, size_t num_return_vals) {
|
||||
void Worker::register_remote_function(const std::string& name, size_t num_return_vals) {
|
||||
RAY_CHECK(connected_, "Attempted to perform register_function but failed.");
|
||||
ClientContext context;
|
||||
RegisterFunctionRequest request;
|
||||
request.set_fnname(name);
|
||||
request.set_num_return_vals(num_return_vals);
|
||||
RegisterRemoteFunctionRequest request;
|
||||
request.set_workerid(workerid_);
|
||||
request.set_function_name(name);
|
||||
request.set_num_return_vals(num_return_vals);
|
||||
AckReply reply;
|
||||
scheduler_stub_->RegisterFunction(&context, request, &reply);
|
||||
scheduler_stub_->RegisterRemoteFunction(&context, request, &reply);
|
||||
}
|
||||
|
||||
// Send a NotifyFailure RPC to the scheduler describing a failure.
//
//   type          - which kind of failure occurred.
//   name          - name stored in the failure record.
//   error_message - error text stored in the failure record.
//
// The record also carries this worker's id and address. The status returned by
// the RPC is not inspected.
void Worker::notify_failure(FailedType type, const std::string& name, const std::string& error_message) {
  RAY_CHECK(connected_, "Attempted to perform notify_failure but failed.");
  NotifyFailureRequest request;
  // Fill in the nested failure record through a single mutable handle.
  auto* failure = request.mutable_failure();
  failure->set_type(type);
  failure->set_workerid(workerid_);
  failure->set_worker_address(worker_address_);
  failure->set_name(name);
  failure->set_error_message(error_message);
  ClientContext context;
  AckReply reply;
  scheduler_stub_->NotifyFailure(&context, request, &reply);
}
|
||||
|
||||
std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
|
||||
@@ -329,24 +374,19 @@ std::unique_ptr<WorkerMessage> Worker::receive_next_message() {
|
||||
return std::unique_ptr<WorkerMessage>(message_ptr);
|
||||
}
|
||||
|
||||
void Worker::notify_task_completed(bool task_succeeded, std::string error_message) {
|
||||
void Worker::notify_task_completed() {
|
||||
RAY_CHECK(connected_, "Attempted to perform notify_task_completed but failed.");
|
||||
ClientContext context;
|
||||
ReadyForNewTaskRequest request;
|
||||
request.set_workerid(workerid_);
|
||||
ReadyForNewTaskRequest::PreviousTaskInfo* previous_task_info = request.mutable_previous_task_info();
|
||||
previous_task_info->set_task_succeeded(task_succeeded);
|
||||
previous_task_info->set_error_message(error_message);
|
||||
AckReply reply;
|
||||
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
|
||||
}
|
||||
|
||||
void Worker::disconnect() {
|
||||
connected_ = false;
|
||||
}
|
||||
|
||||
bool Worker::connected() {
|
||||
return connected_;
|
||||
// TODO(rkn): This probably isn't the right way to clean up the thread.
|
||||
worker_server_thread_->detach();
|
||||
}
|
||||
|
||||
// TODO(rkn): Should we be using pointers or references? And should they be const?
|
||||
@@ -360,13 +400,14 @@ void Worker::task_info(ClientContext &context, TaskInfoRequest &request, TaskInf
|
||||
scheduler_stub_->TaskInfo(&context, request, &reply);
|
||||
}
|
||||
|
||||
bool Worker::export_function(const std::string& function) {
|
||||
bool Worker::export_remote_function(const std::string& function_name, const std::string& function) {
|
||||
RAY_CHECK(connected_, "Attempted to export function but failed.");
|
||||
ClientContext context;
|
||||
ExportFunctionRequest request;
|
||||
ExportRemoteFunctionRequest request;
|
||||
request.mutable_function()->set_name(function_name);
|
||||
request.mutable_function()->set_implementation(function);
|
||||
ExportFunctionReply reply;
|
||||
Status status = scheduler_stub_->ExportFunction(&context, request, &reply);
|
||||
AckReply reply;
|
||||
Status status = scheduler_stub_->ExportRemoteFunction(&context, request, &reply);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -385,26 +426,32 @@ void Worker::export_reusable_variable(const std::string& name, const std::string
|
||||
// queue. This is because the Python interpreter needs to be single threaded
|
||||
// (in our case running in the main thread), whereas the WorkerService will
|
||||
// run in a separate thread and potentially utilize multiple threads.
|
||||
void Worker::start_worker_service() {
|
||||
void Worker::start_worker_service(Mode mode) {
|
||||
const char* service_addr = worker_address_.c_str();
|
||||
worker_server_thread_ = std::thread([this, service_addr]() {
|
||||
// Launch a new thread for running the worker service. We store this as a
|
||||
// field so that we can clean it up when we disconnect the worker.
|
||||
worker_server_thread_ = std::unique_ptr<std::thread>(new std::thread([this, service_addr, mode]() {
|
||||
std::string service_address(service_addr);
|
||||
std::string::iterator split_point = split_ip_address(service_address);
|
||||
std::string port;
|
||||
port.assign(split_point, service_address.end());
|
||||
WorkerServiceImpl service(service_address);
|
||||
// Create the worker service.
|
||||
WorkerServiceImpl service(service_address, mode);
|
||||
ServerBuilder builder;
|
||||
builder.AddListeningPort(std::string("0.0.0.0:") + port, grpc::InsecureServerCredentials());
|
||||
builder.RegisterService(&service);
|
||||
std::unique_ptr<Server> server(builder.BuildAndStart());
|
||||
RAY_LOG(RAY_INFO, "worker server listening on " << service_address);
|
||||
|
||||
ClientContext context;
|
||||
ReadyForNewTaskRequest request;
|
||||
request.set_workerid(workerid_);
|
||||
AckReply reply;
|
||||
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
|
||||
|
||||
// If this is part of a worker process (and not a driver process), then tell
|
||||
// the scheduler that it is ready to start receiving tasks.
|
||||
if (mode == Mode::WORKER_MODE) {
|
||||
ClientContext context;
|
||||
ReadyForNewTaskRequest request;
|
||||
request.set_workerid(workerid_);
|
||||
AckReply reply;
|
||||
scheduler_stub_->ReadyForNewTask(&context, request, &reply);
|
||||
}
|
||||
// Wait for work and process work, this does not return.
|
||||
server->Wait();
|
||||
});
|
||||
}));
|
||||
}
|
||||
|
||||
+27
-15
@@ -23,16 +23,25 @@ using grpc::Channel;
|
||||
using grpc::ClientContext;
|
||||
using grpc::ClientWriter;
|
||||
|
||||
// These four constants are used to define the mode that a worker is running
|
||||
// in. Right now, this is mostly used for determining how to print information
|
||||
// about task failures.
|
||||
enum Mode {SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE};
|
||||
|
||||
class WorkerServiceImpl final : public WorkerService::Service {
|
||||
public:
|
||||
WorkerServiceImpl(const std::string& worker_address);
|
||||
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, ExecuteTaskReply* reply) override;
|
||||
Status ImportFunction(ServerContext* context, const ImportFunctionRequest* request, ImportFunctionReply* reply) override;
|
||||
Status Die(ServerContext* context, const DieRequest* request, DieReply* reply) override;
|
||||
WorkerServiceImpl(const std::string& worker_address, Mode mode);
|
||||
Status ExecuteTask(ServerContext* context, const ExecuteTaskRequest* request, AckReply* reply) override;
|
||||
Status ImportRemoteFunction(ServerContext* context, const ImportRemoteFunctionRequest* request, AckReply* reply) override;
|
||||
Status Die(ServerContext* context, const DieRequest* request, AckReply* reply) override;
|
||||
Status ImportReusableVariable(ServerContext* context, const ImportReusableVariableRequest* request, AckReply* reply) override;
|
||||
Status PrintErrorMessage(ServerContext* context, const PrintErrorMessageRequest* request, AckReply* reply) override;
|
||||
private:
|
||||
std::string worker_address_;
|
||||
MessageQueue<WorkerMessage*> send_queue_;
|
||||
// The mode that this worker service is running in. It distinguishes a service
|
||||
// that is part of a driver process from one that is part of a worker process.
|
||||
Mode mode_;
|
||||
};
|
||||
|
||||
class Worker {
|
||||
@@ -71,27 +80,30 @@ class Worker {
|
||||
void increment_reference_count(std::vector<ObjectID> &objectid);
|
||||
// decrement the reference count for objectid
|
||||
void decrement_reference_count(std::vector<ObjectID> &objectid);
|
||||
// register function with scheduler
|
||||
void register_function(const std::string& name, size_t num_return_vals);
|
||||
// start the worker server which accepts tasks from the scheduler and stores
|
||||
// it in the message queue, which is read by the Python interpreter
|
||||
void start_worker_service();
|
||||
// Notify the scheduler that a remote function has been imported successfully.
|
||||
void register_remote_function(const std::string& name, size_t num_return_vals);
|
||||
// Notify the scheduler that a failure has occurred.
|
||||
void notify_failure(FailedType type, const std::string& name, const std::string& error_message);
|
||||
// Start the worker server which accepts commands from the scheduler. For
|
||||
// workers, these commands are stored in the message queue, which is read by
|
||||
// the Python interpreter. For drivers, these commands are only for printing
|
||||
// error messages.
|
||||
void start_worker_service(Mode mode);
|
||||
// wait for next task from the RPC system. If null, it means there are no more tasks and the worker should shut down.
|
||||
std::unique_ptr<WorkerMessage> receive_next_message();
|
||||
// tell the scheduler that we are done with the current task and request the
|
||||
// next one, if task_succeeded is false, this tells the scheduler that the
|
||||
// task threw an exception
|
||||
void notify_task_completed(bool task_succeeded, std::string error_message);
|
||||
// next one.
|
||||
void notify_task_completed();
|
||||
// disconnect the worker
|
||||
void disconnect();
|
||||
// return connected_
|
||||
bool connected();
|
||||
bool connected() { return connected_; }
|
||||
// get info about scheduler state
|
||||
void scheduler_info(ClientContext &context, SchedulerInfoRequest &request, SchedulerInfoReply &reply);
|
||||
// get task statuses from scheduler
|
||||
void task_info(ClientContext &context, TaskInfoRequest &request, TaskInfoReply &reply);
|
||||
// export function to workers
|
||||
bool export_function(const std::string& function);
|
||||
bool export_remote_function(const std::string& function_name, const std::string& function);
|
||||
// export reusable variable to workers
|
||||
void export_reusable_variable(const std::string& name, const std::string& initializer, const std::string& reinitializer);
|
||||
// return the worker address
|
||||
@@ -101,7 +113,7 @@ class Worker {
|
||||
bool connected_;
|
||||
const size_t CHUNK_SIZE = 8 * 1024;
|
||||
std::unique_ptr<Scheduler::Stub> scheduler_stub_;
|
||||
std::thread worker_server_thread_;
|
||||
std::unique_ptr<std::thread> worker_server_thread_;
|
||||
MessageQueue<WorkerMessage*> receive_queue_;
|
||||
bip::managed_shared_memory segment_;
|
||||
WorkerId workerid_;
|
||||
|
||||
+57
-3
@@ -249,14 +249,12 @@ class APITest(unittest.TestCase):
|
||||
task_info = ray.task_info()
|
||||
self.assertEqual(len(task_info["failed_tasks"]), 0)
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
self.assertEqual(task_info["num_succeeded"], 1)
|
||||
|
||||
test_functions.no_op_fail.remote()
|
||||
time.sleep(0.2)
|
||||
task_info = ray.task_info()
|
||||
self.assertEqual(len(task_info["failed_tasks"]), 1)
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
self.assertEqual(task_info["num_succeeded"], 1)
|
||||
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
|
||||
|
||||
ray.worker.cleanup()
|
||||
@@ -273,7 +271,6 @@ class APITest(unittest.TestCase):
|
||||
task_info = ray.task_info()
|
||||
self.assertEqual(len(task_info["failed_tasks"]), 2)
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
self.assertEqual(task_info["num_succeeded"], 0)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
@@ -391,6 +388,63 @@ class TaskStatusTest(unittest.TestCase):
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
# Check that an exception raised while a remote function is being imported is
# recorded under "failed_remote_function_imports" in ray.task_info().
def testFailImportingRemoteFunction(self):
  ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)

  # This example is somewhat contrived. It should be successfully pickled, and
  # then it should throw an exception when it is unpickled. This may depend a
  # bit on the specifics of our pickler.
  def reducer(*args):
    raise Exception("There is a problem here.")

  class Foo(object):
    def __init__(self):
      self.__name__ = "Foo_object"
      self.func_doc = ""
      self.__globals__ = {}

    def __reduce__(self):
      return reducer, ()

    def __call__(self):
      return

  ray.remote([], [])(Foo())
  # Give the failure time to be reported before querying the task info.
  time.sleep(0.1)
  failed_imports = ray.task_info()["failed_remote_function_imports"]
  # assertIn reports both operands on failure, unlike assertTrue(a in b).
  self.assertIn("There is a problem here.", failed_imports[0]["error_message"])

  ray.worker.cleanup()
|
||||
|
||||
# Check that an exception raised by a reusable variable's initializer in worker
# mode is recorded under "failed_reusable_variable_imports" in ray.task_info().
def testFailImportingReusableVariable(self):
  ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)

  # This will throw an exception when the reusable variable is imported on the
  # workers.
  def initializer():
    if ray.worker.global_worker.mode == ray.WORKER_MODE:
      raise Exception("The initializer failed.")
    return 0

  ray.reusables.foo = ray.Reusable(initializer)
  # Give the failure time to be reported before querying the task info.
  time.sleep(0.1)
  # Check that the error message is in the task info. assertIn reports both
  # operands on failure, unlike assertTrue(a in b).
  failed_imports = ray.task_info()["failed_reusable_variable_imports"]
  self.assertIn("The initializer failed.", failed_imports[0]["error_message"])

  ray.worker.cleanup()
|
||||
|
||||
# Check that an exception raised by a reusable variable's reinitializer is
# recorded under "failed_reinitialize_reusable_variables" in ray.task_info().
def testFailReinitializingVariable(self):
  ray.init(start_ray_local=True, num_workers=2, driver_mode=ray.SILENT_MODE)

  def initializer():
    return 0

  def reinitializer(foo):
    raise Exception("The reinitializer failed.")

  ray.reusables.foo = ray.Reusable(initializer, reinitializer)

  @ray.remote([], [])
  def use_foo():
    ray.reusables.foo

  use_foo.remote()
  # Give the failure time to be reported before querying the task info.
  time.sleep(0.1)
  # Check that the error message is in the task info. assertIn reports both
  # operands on failure, unlike assertTrue(a in b).
  failed_reinits = ray.task_info()["failed_reinitialize_reusable_variables"]
  self.assertIn("The reinitializer failed.", failed_reinits[0]["error_message"])

  ray.worker.cleanup()
|
||||
|
||||
def check_get_deallocated(data):
|
||||
x = ray.put(data)
|
||||
ray.get(x)
|
||||
|
||||
Reference in New Issue
Block a user