diff --git a/python/ray/actor.py b/python/ray/actor.py index 7fd9fad62..5ce1d7d40 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -6,15 +6,13 @@ import cloudpickle as pickle import hashlib import inspect import json -import numpy as np -import redis import traceback import ray.local_scheduler import ray.signature as signature import ray.worker -from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary, - random_string) +from ray.utils import (FunctionProperties, random_string, + select_local_scheduler) def random_actor_id(): @@ -102,117 +100,6 @@ def fetch_and_register_actor(actor_class_key, worker): # for the actor. -def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker): - """Attempt to acquire GPUs on a particular local scheduler for an actor. - - Args: - num_gpus: The number of GPUs to acquire. - driver_id: The ID of the driver responsible for creating the actor. - local_scheduler: Information about the local scheduler. - - Returns: - True if the GPUs were successfully reserved and false otherwise. - """ - assert num_gpus != 0 - local_scheduler_id = local_scheduler["DBClientID"] - local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) - - success = False - - # Attempt to acquire GPU IDs atomically. - with worker.redis_client.pipeline() as pipe: - while True: - try: - # If this key is changed before the transaction below (the - # multi/exec block), then the transaction will not take place. - pipe.watch(local_scheduler_id) - - # Figure out which GPUs are currently in use. - result = worker.redis_client.hget(local_scheduler_id, - "gpus_in_use") - gpus_in_use = dict() if result is None else json.loads( - result.decode("ascii")) - num_gpus_in_use = 0 - for key in gpus_in_use: - num_gpus_in_use += gpus_in_use[key] - assert num_gpus_in_use <= local_scheduler_total_gpus - - pipe.multi() - - if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus: - # There are enough available GPUs, so try to reserve some. - # We use the hex driver ID in hex as a dictionary key so - # that the dictionary is JSON serializable. - driver_id_hex = binary_to_hex(driver_id) - if driver_id_hex not in gpus_in_use: - gpus_in_use[driver_id_hex] = 0 - gpus_in_use[driver_id_hex] += num_gpus - - # Stick the updated GPU IDs back in Redis - pipe.hset(local_scheduler_id, "gpus_in_use", - json.dumps(gpus_in_use)) - success = True - - pipe.execute() - # If a WatchError is not raised, then the operations should - # have gone through atomically. - break - except redis.WatchError: - # Another client must have changed the watched key between the - # time we started WATCHing it and the pipeline's execution. We - # should just retry. - success = False - continue - - return success - - -def select_local_scheduler(local_schedulers, num_gpus, worker): - """Select a local scheduler to assign this actor to. - - Args: - local_schedulers: A list of dictionaries of information about the local - schedulers. - num_gpus (int): The number of GPUs that must be reserved for this - actor. - - Returns: - The ID of the local scheduler that has been chosen. - - Raises: - Exception: An exception is raised if no local scheduler can be found - with sufficient resources. - """ - driver_id = worker.task_driver_id.id() - - local_scheduler_id = None - # Loop through all of the local schedulers in a random order. - local_schedulers = np.random.permutation(local_schedulers) - for local_scheduler in local_schedulers: - if local_scheduler["NumCPUs"] < 1: - continue - if local_scheduler["NumGPUs"] < num_gpus: - continue - if num_gpus == 0: - local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) - break - else: - # Try to reserve enough GPUs on this local scheduler. - success = attempt_to_reserve_gpus(num_gpus, driver_id, - local_scheduler, worker) - if success: - local_scheduler_id = hex_to_binary( - local_scheduler["DBClientID"]) - break - - if local_scheduler_id is None: - raise Exception("Could not find a node with enough GPUs or other " - "resources to create this actor. The local scheduler " - "information is {}.".format(local_schedulers)) - - return local_scheduler_id - - def export_actor_class(class_id, Class, actor_method_names, worker): if worker.mode is None: raise NotImplemented("TODO(pcm): Cache actors") @@ -255,17 +142,10 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus, num_gpus=0, max_calls=0)) - # Get a list of the local schedulers from the client table. - client_table = ray.global_state.client_table() - local_schedulers = [] - for ip_address, clients in client_table.items(): - for client in clients: - if (client["ClientType"] == "local_scheduler" and - not client["Deleted"]): - local_schedulers.append(client) # Select a local scheduler for the actor. - local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus, - worker) + local_scheduler_id = select_local_scheduler( + worker.task_driver_id.id(), ray.global_state.local_schedulers(), + num_gpus, worker.redis_client) assert local_scheduler_id is not None # We must put the actor information in Redis before publishing the actor @@ -274,17 +154,12 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus, worker.redis_client.hmset(key, {"class_id": class_id, "num_gpus": num_gpus}) - # Really we should encode this message as a flatbuffer object. However, - # we're having trouble getting that to work. It almost works, but in Python - # 2.7, builder.CreateString fails on byte strings that contain characters - # outside range(128). - # TODO(rkn): There is actually no guarantee that the local scheduler that # we are publishing to has already subscribed to the actor_notifications # channel. Therefore, this message may be missed and the workload will # hang. This is a bug. - worker.redis_client.publish("actor_notifications", - actor_id.id() + driver_id + local_scheduler_id) + ray.utils.publish_actor_creation(actor_id.id(), driver_id, + local_scheduler_id, worker.redis_client) def actor(*args, **kwargs): @@ -319,8 +194,7 @@ def make_actor(cls, num_cpus, num_gpus): args = signature.extend_args(function_signature, args, kwargs) function_id = get_actor_method_function_id(attr) - object_ids = ray.worker.global_worker.submit_task(function_id, "", - args, + object_ids = ray.worker.global_worker.submit_task(function_id, args, actor_id=actor_id) if len(object_ids) == 1: return object_ids[0] diff --git a/python/ray/experimental/state.py b/python/ray/experimental/state.py index 71ce78cbc..d3c260a3e 100644 --- a/python/ray/experimental/state.py +++ b/python/ray/experimental/state.py @@ -642,6 +642,21 @@ class GlobalState(object): all_times.append(data["store_outputs_end"]) return all_times + def local_schedulers(self): + """Get a list of live local schedulers. + + Returns: + A list of the live local schedulers. + """ + clients = self.client_table() + local_schedulers = [] + for ip_address, client_list in clients.items(): + for client in client_list: + if (client["ClientType"] == "local_scheduler" and + not client["Deleted"]): + local_schedulers.append(client) + return local_schedulers + def workers(self): """Get a dictionary mapping worker ID to worker information.""" worker_keys = self.redis_client.keys("Worker*") @@ -666,6 +681,22 @@ class GlobalState(object): } return workers_data + def actors(self): + actor_keys = self.redis_client.keys("Actor:*") + actor_info = dict() + for key in actor_keys: + info = self.redis_client.hgetall(key) + actor_id = key[len("Actor:"):] + assert len(actor_id) == 20 + actor_info[binary_to_hex(actor_id)] = { + "class_id": binary_to_hex(info[b"class_id"]), + "driver_id": binary_to_hex(info[b"driver_id"]), + "local_scheduler_id": + binary_to_hex(info[b"local_scheduler_id"]), + "num_gpus": int(info[b"num_gpus"]), + "removed": decode(info[b"removed"]) == "True"} + return actor_info + def _job_length(self): event_log_sets = self.redis_client.keys("event_log*") overall_smallest = sys.maxsize diff --git a/python/ray/utils.py b/python/ray/utils.py index 2f6ed1423..b92c0f6b4 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -4,7 +4,9 @@ from __future__ import print_function import binascii import collections +import json import numpy as np +import redis import sys import ray.local_scheduler @@ -65,3 +67,141 @@ FunctionProperties = collections.namedtuple("FunctionProperties", "num_gpus", "max_calls"]) """FunctionProperties: A named tuple storing remote functions information.""" + + +def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, + redis_client): + """Attempt to acquire GPUs on a particular local scheduler for an actor. + + Args: + num_gpus: The number of GPUs to acquire. + driver_id: The ID of the driver responsible for creating the actor. + local_scheduler: Information about the local scheduler. + redis_client: The redis client to use for interacting with Redis. + + Returns: + True if the GPUs were successfully reserved and false otherwise. + """ + assert num_gpus != 0 + local_scheduler_id = local_scheduler["DBClientID"] + local_scheduler_total_gpus = int(local_scheduler["NumGPUs"]) + + success = False + + # Attempt to acquire GPU IDs atomically. + with redis_client.pipeline() as pipe: + while True: + try: + # If this key is changed before the transaction below (the + # multi/exec block), then the transaction will not take place. + pipe.watch(local_scheduler_id) + + # Figure out which GPUs are currently in use. + result = redis_client.hget(local_scheduler_id, "gpus_in_use") + gpus_in_use = dict() if result is None else json.loads( + result.decode("ascii")) + num_gpus_in_use = 0 + for key in gpus_in_use: + num_gpus_in_use += gpus_in_use[key] + assert num_gpus_in_use <= local_scheduler_total_gpus + + pipe.multi() + + if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus: + # There are enough available GPUs, so try to reserve some. + # We use the hex driver ID in hex as a dictionary key so + # that the dictionary is JSON serializable. + driver_id_hex = binary_to_hex(driver_id) + if driver_id_hex not in gpus_in_use: + gpus_in_use[driver_id_hex] = 0 + gpus_in_use[driver_id_hex] += num_gpus + + # Stick the updated GPU IDs back in Redis + pipe.hset(local_scheduler_id, "gpus_in_use", + json.dumps(gpus_in_use)) + success = True + + pipe.execute() + # If a WatchError is not raised, then the operations should + # have gone through atomically. + break + except redis.WatchError: + # Another client must have changed the watched key between the + # time we started WATCHing it and the pipeline's execution. We + # should just retry. + success = False + continue + + return success + + +def select_local_scheduler(driver_id, local_schedulers, num_gpus, + redis_client): + """Select a local scheduler to assign this actor to. + + Args: + driver_id: The ID of the driver who the actor is for. + local_schedulers: A list of dictionaries of information about the local + schedulers. + num_gpus (int): The number of GPUs that must be reserved for this + actor. + redis_client: The Redis client to use for interacting with Redis. + + Returns: + The ID of the local scheduler that has been chosen. + + Raises: + Exception: An exception is raised if no local scheduler can be found + with sufficient resources. + """ + local_scheduler_id = None + # Loop through all of the local schedulers in a random order. + local_schedulers = np.random.permutation(local_schedulers) + for local_scheduler in local_schedulers: + if local_scheduler["NumCPUs"] < 1: + continue + if local_scheduler["NumGPUs"] < num_gpus: + continue + if num_gpus == 0: + local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"]) + break + else: + # Try to reserve enough GPUs on this local scheduler. + success = attempt_to_reserve_gpus(num_gpus, driver_id, + local_scheduler, redis_client) + if success: + local_scheduler_id = hex_to_binary( + local_scheduler["DBClientID"]) + break + + if local_scheduler_id is None: + raise Exception("Could not find a node with enough GPUs or other " + "resources to create this actor. The local scheduler " + "information is {}.".format(local_schedulers)) + + return local_scheduler_id + + +def publish_actor_creation(actor_id, driver_id, local_scheduler_id, + redis_client): + """Publish a notification that an actor should be created. + + This broadcast will be received by all of the local schedulers. The local + scheduler whose ID is being broadcast will create the actor. Any other + local schedulers that have already created the actor will kill it. All + local schedulers will update their internal data structures to redirect + tasks for this actor to the new local scheduler. + + Args: + actor_id: The ID of the actor involved. + driver_id: The ID of the driver responsible for the actor. + local_scheduler_id: The ID of the local scheduler that is suposed to + create the actor. + redis_client: The client used to interact with Redis. + """ + # Really we should encode this message as a flatbuffer object. However, + # we're having trouble getting that to work. It almost works, but in Python + # 2.7, builder.CreateString fails on byte strings that contain characters + # outside range(128). + redis_client.publish("actor_notifications", + actor_id + driver_id + local_scheduler_id) diff --git a/python/ray/worker.py b/python/ray/worker.py index 52ebcd234..fa956ac30 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -473,15 +473,14 @@ class Worker(object): assert final_results[i][0] == object_ids[i].id() return [result[1][0] for result in final_results] - def submit_task(self, function_id, func_name, args, actor_id=None): + def submit_task(self, function_id, args, actor_id=None): """Submit a remote task to the scheduler. - Tell the scheduler to schedule the execution of the function with name - func_name with arguments args. Retrieve object IDs for the outputs of + Tell the scheduler to schedule the execution of the function with ID + function_id with arguments args. Retrieve object IDs for the outputs of the function from the scheduler and immediately return them. Args: - func_name (str): The name of the function to be executed. args (List[Any]): The arguments to pass into the function. Arguments can be object IDs or they can be values. If they are values, they must be serializable objecs. @@ -513,7 +512,8 @@ class Worker(object): function_properties.num_return_vals, self.current_task_id, self.task_index, - actor_id, self.actor_counters[actor_id], + actor_id, + self.actor_counters[actor_id], [function_properties.num_cpus, function_properties.num_gpus]) # Increment the worker's task index to track how many tasks have # been submitted by the current task so far. @@ -582,6 +582,260 @@ class Worker(object): "data": data}) self.redis_client.rpush("ErrorKeys", error_key) + def _wait_for_function(self, function_id, driver_id, timeout=10): + """Wait until the function to be executed is present on this worker. + + This method will simply loop until the import thread has imported the + relevant function. If we spend too long in this loop, that may indicate + a problem somewhere and we will push an error message to the user. + + If this worker is an actor, then this will wait until the actor has + been defined. + + Args: + is_actor (bool): True if this worker is an actor, and false + otherwise. + function_id (str): The ID of the function that we want to execute. + driver_id (str): The ID of the driver to push the error message to + if this times out. + """ + start_time = time.time() + # Only send the warning once. + warning_sent = False + while True: + with self.lock: + if (self.actor_id == NIL_ACTOR_ID and + (function_id.id() in self.functions[driver_id])): + break + elif self.actor_id != NIL_ACTOR_ID and (self.actor_id in + self.actors): + break + if time.time() - start_time > timeout: + warning_message = ("This worker was asked to execute a " + "function that it does not have " + "registered. You may have to restart " + "Ray.") + if not warning_sent: + self.push_error_to_driver(driver_id, + "wait_for_function", + warning_message) + warning_sent = True + time.sleep(0.001) + + def _get_arguments_for_execution(self, function_name, serialized_args): + """Retrieve the arguments for the remote function. + + This retrieves the values for the arguments to the remote function that + were passed in as object IDs. Argumens that were passed by value are + not changed. This is called by the worker that is executing the remote + function. + + Args: + function_name (str): The name of the remote function whose + arguments are being retrieved. + serialized_args (List): The arguments to the function. These are + either strings representing serialized objects passed by value + or they are ObjectIDs. + + Returns: + The retrieved arguments in addition to the arguments that were + passed by value. + + Raises: + RayGetArgumentError: This exception is raised if a task that + created one of the arguments failed. + """ + arguments = [] + for (i, arg) in enumerate(serialized_args): + if isinstance(arg, ray.local_scheduler.ObjectID): + # get the object from the local object store + argument = self.get_object([arg])[0] + if isinstance(argument, RayTaskError): + # If the result is a RayTaskError, then the task that + # created this object failed, and we should propagate the + # error message here. + raise RayGetArgumentError(function_name, i, arg, argument) + else: + # pass the argument by value + argument = arg + + arguments.append(argument) + return arguments + + def _store_outputs_in_objstore(self, objectids, outputs): + """Store the outputs of a remote function in the local object store. + + This stores the values that were returned by a remote function in the + local object store. If any of the return values are object IDs, then + these object IDs are aliased with the object IDs that the scheduler + assigned for the return values. This is called by the worker that + executes the remote function. + + Note: + The arguments objectids and outputs should have the same length. + + Args: + objectids (List[ObjectID]): The object IDs that were assigned to + the outputs of the remote function call. + outputs (Tuple): The value returned by the remote function. If the + remote function was supposed to only return one value, then its + output was wrapped in a tuple with one element prior to being + passed into this function. + """ + for i in range(len(objectids)): + self.put_object(objectids[i], outputs[i]) + + def _process_task(self, task): + """Execute a task assigned to this worker. + + This method deserializes a task from the scheduler, and attempts to + execute the task. If the task succeeds, the outputs are stored in the + local object store. If the task throws an exception, RayTaskError + objects are stored in the object store to represent the failed task + (these will be retrieved by calls to get or by subsequent tasks that + use the outputs of this task). + """ + try: + # The ID of the driver that this task belongs to. This is needed so + # that if the task throws an exception, we propagate the error + # message to the correct driver. + self.task_driver_id = task.driver_id() + self.current_task_id = task.task_id() + self.current_function_id = task.function_id().id() + self.task_index = 0 + self.put_index = 0 + function_id = task.function_id() + args = task.arguments() + return_object_ids = task.returns() + function_name, function_executor = (self.functions + [self.task_driver_id.id()] + [function_id.id()]) + + # Get task arguments from the object store. + with log_span("ray:task:get_arguments", worker=self): + arguments = self._get_arguments_for_execution(function_name, + args) + + # Execute the task. + with log_span("ray:task:execute", worker=self): + if task.actor_id().id() == NIL_ACTOR_ID: + outputs = function_executor.executor(arguments) + else: + outputs = function_executor( + self.actors[task.actor_id().id()], *arguments) + + # Store the outputs in the local object store. + with log_span("ray:task:store_outputs", worker=self): + if len(return_object_ids) == 1: + outputs = (outputs,) + self._store_outputs_in_objstore(return_object_ids, outputs) + except Exception as e: + # We determine whether the exception was caused by the call to + # _get_arguments_for_execution or by the execution of the remote + # function or by the call to _store_outputs_in_objstore. Depending + # on which case occurred, we format the error message differently. + # whether the variables "arguments" and "outputs" are defined. + if "arguments" in locals() and "outputs" not in locals(): + if task.actor_id().id() == NIL_ACTOR_ID: + # The error occurred during the task execution. + traceback_str = format_error_message( + traceback.format_exc(), task_exception=True) + else: + # The error occurred during the execution of an actor task. + traceback_str = format_error_message( + traceback.format_exc()) + elif "arguments" in locals() and "outputs" in locals(): + # The error occurred after the task executed. + traceback_str = format_error_message(traceback.format_exc()) + else: + # The error occurred before the task execution. + if (isinstance(e, RayGetError) or + isinstance(e, RayGetArgumentError)): + # In this case, getting the task arguments failed. + traceback_str = None + else: + traceback_str = traceback.format_exc() + failure_object = RayTaskError(function_name, e, traceback_str) + failure_objects = [failure_object for _ + in range(len(return_object_ids))] + self._store_outputs_in_objstore(return_object_ids, failure_objects) + # Log the error message. + self.push_error_to_driver(self.task_driver_id.id(), "task", + str(failure_object), + data={"function_id": function_id.id(), + "function_name": function_name}) + + def _wait_for_and_process_task(self, task): + """Wait for a task to be ready and process the task. + + Args: + task: The task to execute. + """ + function_id = task.function_id() + # Wait until the function to be executed has actually been registered + # on this worker. We will push warnings to the user if we spend too + # long in this loop. + with log_span("ray:wait_for_function", worker=self): + self._wait_for_function(function_id, task.driver_id().id()) + + # Execute the task. + # TODO(rkn): Consider acquiring this lock with a timeout and pushing a + # warning to the user if we are waiting too long to acquire the lock + # because that may indicate that the system is hanging, and it'd be + # good to know where the system is hanging. + log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self) + with self.lock: + log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, + worker=self) + + function_name, _ = (self.functions[task.driver_id().id()] + [function_id.id()]) + contents = {"function_name": function_name, + "task_id": task.task_id().hex(), + "worker_id": binary_to_hex(self.worker_id)} + with log_span("ray:task", contents=contents, worker=self): + self._process_task(task) + + # Push all of the log events to the global state store. + flush_log() + + # Increase the task execution counter. + (self.num_task_executions[task.driver_id().id()] + [function_id.id()]) += 1 + + reached_max_executions = ( + self.num_task_executions[task.driver_id().id()] + [function_id.id()] == + self.function_properties[task.driver_id().id()] + [function_id.id()].max_calls) + if reached_max_executions: + ray.worker.global_worker.local_scheduler_client.disconnect() + os._exit(0) + + def _get_next_task_from_local_scheduler(self): + """Get the next task from the local scheduler. + + Returns: + A task from the local scheduler. + """ + with log_span("ray:get_task", worker=self): + task = self.local_scheduler_client.get_task() + return task + + def main_loop(self): + """The main loop a worker runs to receive and execute tasks.""" + + def exit(signum, frame): + cleanup(worker=self) + sys.exit(0) + + signal.signal(signal.SIGTERM, exit) + + check_main_thread() + while True: + task = self._get_next_task_from_local_scheduler() + self._wait_for_and_process_task(task) + def get_gpu_ids(): """Get the IDs of the GPU that are available to the worker. @@ -1731,45 +1985,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker): return ready_ids, remaining_ids -def wait_for_function(function_id, driver_id, timeout=10, - worker=global_worker): - """Wait until the function to be executed is present on this worker. - - This method will simply loop until the import thread has imported the - relevant function. If we spend too long in this loop, that may indicate a - problem somewhere and we will push an error message to the user. - - If this worker is an actor, then this will wait until the actor has been - defined. - - Args: - is_actor (bool): True if this worker is an actor, and false otherwise. - function_id (str): The ID of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to if - this times out. - """ - start_time = time.time() - # Only send the warning once. - warning_sent = False - while True: - with worker.lock: - if (worker.actor_id == NIL_ACTOR_ID and - (function_id.id() in worker.functions[driver_id])): - break - elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in - worker.actors): - break - if time.time() - start_time > timeout: - warning_message = ("This worker was asked to execute a " - "function that it does not have " - "registered. You may have to restart Ray.") - if not warning_sent: - worker.push_error_to_driver(driver_id, "wait_for_function", - warning_message) - warning_sent = True - time.sleep(0.001) - - def format_error_message(exception_message, task_exception=False): """Improve the formatting of an exception thrown by a remote function. @@ -1792,145 +2007,7 @@ def format_error_message(exception_message, task_exception=False): return "\n".join(lines) -def main_loop(worker=global_worker): - """The main loop a worker runs to receive and execute tasks.""" - - def exit(signum, frame): - cleanup(worker=worker) - sys.exit(0) - - signal.signal(signal.SIGTERM, exit) - - def process_task(task): - """Execute a task assigned to this worker. - - This method deserializes a task from the scheduler, and attempts to - execute the task. If the task succeeds, the outputs are stored in the - local object store. If the task throws an exception, RayTaskError - objects are stored in the object store to represent the failed task - (these will be retrieved by calls to get or by subsequent tasks that - use the outputs of this task). - """ - try: - # The ID of the driver that this task belongs to. This is needed so - # that if the task throws an exception, we propagate the error - # message to the correct driver. - worker.task_driver_id = task.driver_id() - worker.current_task_id = task.task_id() - worker.current_function_id = task.function_id().id() - worker.task_index = 0 - worker.put_index = 0 - function_id = task.function_id() - args = task.arguments() - return_object_ids = task.returns() - function_name, function_executor = (worker.functions - [worker.task_driver_id.id()] - [function_id.id()]) - - # Get task arguments from the object store. - with log_span("ray:task:get_arguments", worker=worker): - arguments = get_arguments_for_execution(function_name, args, - worker) - - # Execute the task. - with log_span("ray:task:execute", worker=worker): - if task.actor_id().id() == NIL_ACTOR_ID: - outputs = function_executor.executor(arguments) - else: - outputs = function_executor( - worker.actors[task.actor_id().id()], *arguments) - - # Store the outputs in the local object store. - with log_span("ray:task:store_outputs", worker=worker): - if len(return_object_ids) == 1: - outputs = (outputs,) - store_outputs_in_objstore(return_object_ids, outputs, worker) - except Exception as e: - # We determine whether the exception was caused by the call to - # get_arguments_for_execution or by the execution of the remote - # function or by the call to store_outputs_in_objstore. Depending - # on which case occurred, we format the error message differently. - # whether the variables "arguments" and "outputs" are defined. - if "arguments" in locals() and "outputs" not in locals(): - if task.actor_id().id() == NIL_ACTOR_ID: - # The error occurred during the task execution. - traceback_str = format_error_message( - traceback.format_exc(), task_exception=True) - else: - # The error occurred during the execution of an actor task. - traceback_str = format_error_message( - traceback.format_exc()) - elif "arguments" in locals() and "outputs" in locals(): - # The error occurred after the task executed. - traceback_str = format_error_message(traceback.format_exc()) - else: - # The error occurred before the task execution. - if (isinstance(e, RayGetError) or - isinstance(e, RayGetArgumentError)): - # In this case, getting the task arguments failed. - traceback_str = None - else: - traceback_str = traceback.format_exc() - failure_object = RayTaskError(function_name, e, traceback_str) - failure_objects = [failure_object for _ - in range(len(return_object_ids))] - store_outputs_in_objstore(return_object_ids, failure_objects, - worker) - # Log the error message. - worker.push_error_to_driver(worker.task_driver_id.id(), "task", - str(failure_object), - data={"function_id": function_id.id(), - "function_name": function_name}) - - check_main_thread() - while True: - with log_span("ray:get_task", worker=worker): - task = worker.local_scheduler_client.get_task() - - function_id = task.function_id() - # Wait until the function to be executed has actually been registered - # on this worker. We will push warnings to the user if we spend too - # long in this loop. - with log_span("ray:wait_for_function", worker=worker): - wait_for_function(function_id, task.driver_id().id(), - worker=worker) - - # Execute the task. - # TODO(rkn): Consider acquiring this lock with a timeout and pushing a - # warning to the user if we are waiting too long to acquire the lock - # because that may indicate that the system is hanging, and it'd be - # good to know where the system is hanging. - log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker) - with worker.lock: - log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, - worker=worker) - - function_name, _ = (worker.functions[task.driver_id().id()] - [function_id.id()]) - contents = {"function_name": function_name, - "task_id": task.task_id().hex(), - "worker_id": binary_to_hex(worker.worker_id)} - with log_span("ray:task", contents=contents, worker=worker): - process_task(task) - - # Push all of the log events to the global state store. - flush_log() - - # Increase the task execution counter. - (worker.num_task_executions[task.driver_id().id()] - [function_id.id()]) += 1 - - reached_max_executions = ( - worker.num_task_executions[task.driver_id().id()] - [function_id.id()] == - worker.function_properties[task.driver_id().id()] - [function_id.id()].max_calls) - if reached_max_executions: - ray.worker.global_worker.local_scheduler_client.disconnect() - os._exit(0) - - -def _submit_task(function_id, func_name, args, worker=global_worker): +def _submit_task(function_id, args, worker=global_worker): """This is a wrapper around worker.submit_task. We use this wrapper so that in the remote decorator, we can call @@ -1938,7 +2015,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker): attempt to serialize remote functions, we don't attempt to serialize the worker object, which cannot be serialized. """ - return worker.submit_task(function_id, func_name, args) + return worker.submit_task(function_id, args) def _mode(worker=global_worker): @@ -2081,7 +2158,7 @@ def remote(*args, **kwargs): # immutable remote objects. result = func(*copy.deepcopy(args)) return result - objectids = _submit_task(function_id, func_name, args) + objectids = _submit_task(function_id, args) if len(objectids) == 1: return objectids[0] elif len(objectids) > 1: @@ -2157,69 +2234,3 @@ def remote(*args, **kwargs): assert "function_id" not in kwargs return make_remote_decorator(num_return_vals, num_cpus, num_gpus, max_calls) - - -def get_arguments_for_execution(function_name, serialized_args, - worker=global_worker): - """Retrieve the arguments for the remote function. - - This retrieves the values for the arguments to the remote function that - were passed in as object IDs. Argumens that were passed by value are not - changed. This is called by the worker that is executing the remote - function. - - Args: - function_name (str): The name of the remote function whose arguments - are being retrieved. - serialized_args (List): The arguments to the function. These are either - strings representing serialized objects passed by value or they are - ObjectIDs. - - Returns: - The retrieved arguments in addition to the arguments that were passed - by value. - - Raises: - RayGetArgumentError: This exception is raised if a task that created - one of the arguments failed. - """ - arguments = [] - for (i, arg) in enumerate(serialized_args): - if isinstance(arg, ray.local_scheduler.ObjectID): - # get the object from the local object store - argument = worker.get_object([arg])[0] - if isinstance(argument, RayTaskError): - # If the result is a RayTaskError, then the task that created - # this object failed, and we should propagate the error message - # here. - raise RayGetArgumentError(function_name, i, arg, argument) - else: - # pass the argument by value - argument = arg - - arguments.append(argument) - return arguments - - -def store_outputs_in_objstore(objectids, outputs, worker=global_worker): - """Store the outputs of a remote function in the local object store. - - This stores the values that were returned by a remote function in the local - object store. If any of the return values are object IDs, then these object - IDs are aliased with the object IDs that the scheduler assigned for the - return values. This is called by the worker that executes the remote - function. - - Note: - The arguments objectids and outputs should have the same length. - - Args: - objectids (List[ObjectID]): The object IDs that were assigned to the - outputs of the remote function call. - outputs (Tuple): The value returned by the remote function. If the - remote function was supposed to only return one value, then its - output was wrapped in a tuple with one element prior to being - passed into this function. - """ - for i in range(len(objectids)): - worker.put_object(objectids[i], outputs[i]) diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index e037f50c8..63ba3069b 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -30,6 +30,31 @@ def random_string(): return np.random.bytes(20) +def create_redis_client(redis_address): + redis_ip_address, redis_port = redis_address.split(":") + # For this command to work, some other client (on the same machine + # as Redis) must have run "CONFIG SET protected-mode no". + return redis.StrictRedis(host=redis_ip_address, port=int(redis_port)) + + +def push_error_to_all_drivers(redis_client, message): + """Push an error message to all drivers. + + Args: + redis_client: The redis client to use. + message: The error message to push. + """ + DRIVER_ID_LENGTH = 20 + # We use a driver ID of all zeros to push an error message to all + # drivers. + driver_id = DRIVER_ID_LENGTH * b"\x00" + error_key = b"Error:" + driver_id + b":" + random_string() + # Create a Redis client. + redis_client.hmset(error_key, {"type": "worker_crash", + "message": message}) + redis_client.rpush("ErrorKeys", error_key) + + if __name__ == "__main__": args = parser.parse_args() info = {"node_ip_address": args.node_ip_address, @@ -57,25 +82,12 @@ if __name__ == "__main__": # task) should be caught and handled inside of the call to # main_loop. If an exception is thrown here, then that means that # there is some error that we didn't anticipate. - ray.worker.main_loop() + ray.worker.global_worker.main_loop() except Exception as e: traceback_str = traceback.format_exc() + error_explanation - DRIVER_ID_LENGTH = 20 - # We use a driver ID of all zeros to push an error message to all - # drivers. - driver_id = DRIVER_ID_LENGTH * b"\x00" - error_key = b"Error:" + driver_id + b":" + random_string() - redis_ip_address, redis_port = args.redis_address.split(":") - # For this command to work, some other client (on the same machine - # as Redis) must have run "CONFIG SET protected-mode no". - redis_client = redis.StrictRedis(host=redis_ip_address, - port=int(redis_port)) - redis_client.hmset(error_key, {"type": "worker_crash", - "message": traceback_str, - "note": ("This error is unexpected " - "and should not have " - "happened.")}) - redis_client.rpush("ErrorKeys", error_key) + # Create a Redis client. + redis_client = create_redis_client(args.redis_address) + push_error_to_all_drivers(redis_client, traceback_str) # TODO(rkn): Note that if the worker was in the middle of executing # a task, then any worker or driver that is blocking in a get call # and waiting for the output of that task will hang. We need to diff --git a/src/common/lib/python/common_extension.cc b/src/common/lib/python/common_extension.cc index fca10f130..d4ef3f9d9 100644 --- a/src/common/lib/python/common_extension.cc +++ b/src/common/lib/python/common_extension.cc @@ -347,6 +347,11 @@ static PyObject *PyTask_actor_id(PyObject *self) { return PyObjectID_make(actor_id); } +static PyObject *PyTask_actor_counter(PyObject *self) { + int64_t actor_counter = TaskSpec_actor_counter(((PyTask *) self)->spec); + return PyLong_FromLongLong(actor_counter); +} + static PyObject *PyTask_driver_id(PyObject *self) { UniqueID driver_id = TaskSpec_driver_id(((PyTask *) self)->spec); return PyObjectID_make(driver_id); @@ -357,6 +362,16 @@ static PyObject *PyTask_task_id(PyObject *self) { return PyObjectID_make(task_id); } +static PyObject *PyTask_parent_task_id(PyObject *self) { + TaskID task_id = TaskSpec_parent_task_id(((PyTask *) self)->spec); + return PyObjectID_make(task_id); +} + +static PyObject *PyTask_parent_counter(PyObject *self) { + int64_t parent_counter = TaskSpec_parent_counter(((PyTask *) self)->spec); + return PyLong_FromLongLong(parent_counter); +} + static PyObject *PyTask_arguments(PyObject *self) { TaskSpec *task = ((PyTask *) self)->spec; int64_t num_args = TaskSpec_num_args(task); @@ -404,8 +419,14 @@ static PyObject *PyTask_returns(PyObject *self) { static PyMethodDef PyTask_methods[] = { {"function_id", (PyCFunction) PyTask_function_id, METH_NOARGS, "Return the function ID for this task."}, + {"parent_task_id", (PyCFunction) PyTask_parent_task_id, METH_NOARGS, + "Return the task ID of the parent task."}, + {"parent_counter", (PyCFunction) PyTask_parent_counter, METH_NOARGS, + "Return the parent counter of this task."}, {"actor_id", (PyCFunction) PyTask_actor_id, METH_NOARGS, "Return the actor ID for this task."}, + {"actor_counter", (PyCFunction) PyTask_actor_counter, METH_NOARGS, + "Return the actor counter for this task."}, {"driver_id", (PyCFunction) PyTask_driver_id, METH_NOARGS, "Return the driver ID for this task."}, {"task_id", (PyCFunction) PyTask_task_id, METH_NOARGS, diff --git a/src/common/task.cc b/src/common/task.cc index 7dafa9154..f2fd3e1bd 100644 --- a/src/common/task.cc +++ b/src/common/task.cc @@ -226,6 +226,18 @@ UniqueID TaskSpec_driver_id(TaskSpec *spec) { return from_flatbuf(message->driver_id()); } +TaskID TaskSpec_parent_task_id(TaskSpec *spec) { + CHECK(spec); + auto message = flatbuffers::GetRoot(spec); + return from_flatbuf(message->parent_task_id()); +} + +int64_t TaskSpec_parent_counter(TaskSpec *spec) { + CHECK(spec); + auto message = flatbuffers::GetRoot(spec); + return message->parent_counter(); +} + int64_t TaskSpec_num_args(TaskSpec *spec) { CHECK(spec); auto message = flatbuffers::GetRoot(spec); diff --git a/src/common/task.h b/src/common/task.h index 557921462..13f4faa65 100644 --- a/src/common/task.h +++ b/src/common/task.h @@ -143,6 +143,23 @@ int64_t TaskSpec_actor_counter(TaskSpec *spec); */ UniqueID TaskSpec_driver_id(TaskSpec *spec); +/** + * Return the task ID of the parent task. + * + * @param spec The task_spec in question. + * @return The task ID of the parent task. + */ +TaskID TaskSpec_parent_task_id(TaskSpec *spec); + +/** + * Return the task counter of the parent task. For example, this equals 5 if + * this task was the 6th task submitted by the parent task. + * + * @param spec The task_spec in question. + * @return The task counter of the parent task. + */ +int64_t TaskSpec_parent_counter(TaskSpec *spec); + /** * Return the task ID of the task. *