Move worker methods into Worker class and expose more TaskSpec fields to Python. (#796)

* Move worker methods inside worker class. Move some helper methods from actor.py into utils.py and state.py.

* Add more methods exposing task spec fields to Python.

* Fix linting.

* Fix error.

* Remove unused code in default worker.
This commit is contained in:
Robert Nishihara
2017-08-01 17:16:57 -07:00
committed by Philipp Moritz
parent 52a27be364
commit 8c8258de20
8 changed files with 520 additions and 402 deletions
+8 -134
View File
@@ -6,15 +6,13 @@ import cloudpickle as pickle
import hashlib
import inspect
import json
import numpy as np
import redis
import traceback
import ray.local_scheduler
import ray.signature as signature
import ray.worker
from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary,
random_string)
from ray.utils import (FunctionProperties, random_string,
select_local_scheduler)
def random_actor_id():
@@ -102,117 +100,6 @@ def fetch_and_register_actor(actor_class_key, worker):
# for the actor.
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
    """Try to reserve GPUs on one local scheduler on behalf of an actor.

    The per-driver GPU usage of the local scheduler is kept as a JSON
    dictionary in the "gpus_in_use" field of the Redis hash keyed by the
    local scheduler's "DBClientID". The update is done optimistically with
    WATCH/MULTI/EXEC and is retried whenever the watched key changes.

    Args:
        num_gpus: The number of GPUs to acquire. Must be nonzero.
        driver_id: The ID of the driver responsible for creating the actor.
        local_scheduler: Information about the local scheduler.
        worker: The worker whose Redis client is used.

    Returns:
        True if the GPUs were successfully reserved and false otherwise.
    """
    assert num_gpus != 0
    scheduler_key = local_scheduler["DBClientID"]
    total_gpus = int(local_scheduler["NumGPUs"])

    with worker.redis_client.pipeline() as pipe:
        while True:
            reserved = False
            try:
                # The transaction below is abandoned if this key is modified
                # before EXEC runs.
                pipe.watch(scheduler_key)

                # Read the current usage, keyed by hex driver ID.
                raw_usage = worker.redis_client.hget(scheduler_key,
                                                     "gpus_in_use")
                if raw_usage is None:
                    usage = dict()
                else:
                    usage = json.loads(raw_usage.decode("ascii"))
                gpus_taken = sum(usage.values())
                assert gpus_taken <= total_gpus

                pipe.multi()

                if total_gpus - gpus_taken >= num_gpus:
                    # Hex keys keep the dictionary JSON serializable.
                    hex_driver_id = binary_to_hex(driver_id)
                    usage[hex_driver_id] = (usage.get(hex_driver_id, 0) +
                                            num_gpus)
                    pipe.hset(scheduler_key, "gpus_in_use",
                              json.dumps(usage))
                    reserved = True

                pipe.execute()
            except redis.WatchError:
                # Someone else touched the key between WATCH and EXEC, so
                # start over.
                continue
            # No WatchError means the transaction went through atomically.
            return reserved
def select_local_scheduler(local_schedulers, num_gpus, worker):
    """Pick a local scheduler that can host a new actor.

    Candidates are visited in random order. A candidate is skipped if it
    reports fewer than one CPU or fewer GPUs than requested. When GPUs are
    requested, they are reserved on the candidate before it is chosen.

    Args:
        local_schedulers: A list of dictionaries of information about the
            local schedulers.
        num_gpus (int): The number of GPUs that must be reserved for this
            actor.
        worker: The worker supplying the driver ID and Redis client.

    Returns:
        The ID of the local scheduler that has been chosen.

    Raises:
        Exception: An exception is raised if no local scheduler can be found
            with sufficient resources.
    """
    driver_id = worker.task_driver_id.id()
    shuffled = np.random.permutation(local_schedulers)
    for candidate in shuffled:
        if candidate["NumCPUs"] < 1 or candidate["NumGPUs"] < num_gpus:
            continue
        # With no GPU requirement any eligible candidate will do; otherwise
        # the reservation has to succeed first.
        if num_gpus == 0 or attempt_to_reserve_gpus(num_gpus, driver_id,
                                                    candidate, worker):
            return hex_to_binary(candidate["DBClientID"])
    raise Exception("Could not find a node with enough GPUs or other "
                    "resources to create this actor. The local scheduler "
                    "information is {}.".format(shuffled))
def export_actor_class(class_id, Class, actor_method_names, worker):
if worker.mode is None:
raise NotImplemented("TODO(pcm): Cache actors")
@@ -255,17 +142,10 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
num_gpus=0,
max_calls=0))
# Get a list of the local schedulers from the client table.
client_table = ray.global_state.client_table()
local_schedulers = []
for ip_address, clients in client_table.items():
for client in clients:
if (client["ClientType"] == "local_scheduler" and
not client["Deleted"]):
local_schedulers.append(client)
# Select a local scheduler for the actor.
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
worker)
local_scheduler_id = select_local_scheduler(
worker.task_driver_id.id(), ray.global_state.local_schedulers(),
num_gpus, worker.redis_client)
assert local_scheduler_id is not None
# We must put the actor information in Redis before publishing the actor
@@ -274,17 +154,12 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
worker.redis_client.hmset(key, {"class_id": class_id,
"num_gpus": num_gpus})
# Really we should encode this message as a flatbuffer object. However,
# we're having trouble getting that to work. It almost works, but in Python
# 2.7, builder.CreateString fails on byte strings that contain characters
# outside range(128).
# TODO(rkn): There is actually no guarantee that the local scheduler that
# we are publishing to has already subscribed to the actor_notifications
# channel. Therefore, this message may be missed and the workload will
# hang. This is a bug.
worker.redis_client.publish("actor_notifications",
actor_id.id() + driver_id + local_scheduler_id)
ray.utils.publish_actor_creation(actor_id.id(), driver_id,
local_scheduler_id, worker.redis_client)
def actor(*args, **kwargs):
@@ -319,8 +194,7 @@ def make_actor(cls, num_cpus, num_gpus):
args = signature.extend_args(function_signature, args, kwargs)
function_id = get_actor_method_function_id(attr)
object_ids = ray.worker.global_worker.submit_task(function_id, "",
args,
object_ids = ray.worker.global_worker.submit_task(function_id, args,
actor_id=actor_id)
if len(object_ids) == 1:
return object_ids[0]
+31
View File
@@ -642,6 +642,21 @@ class GlobalState(object):
all_times.append(data["store_outputs_end"])
return all_times
def local_schedulers(self):
    """Get a list of live local schedulers.

    A client from the client table is included when its "ClientType" is
    "local_scheduler" and its "Deleted" flag is false.

    Returns:
        A list of the live local schedulers.
    """
    return [
        client
        for client_list in self.client_table().values()
        for client in client_list
        if client["ClientType"] == "local_scheduler" and
        not client["Deleted"]
    ]
def workers(self):
"""Get a dictionary mapping worker ID to worker information."""
worker_keys = self.redis_client.keys("Worker*")
@@ -666,6 +681,22 @@ class GlobalState(object):
}
return workers_data
def actors(self):
    """Get a dictionary mapping hex actor ID to information about the actor.

    The information is read from the Redis hashes whose keys match
    "Actor:*".

    Returns:
        A dictionary keyed by hex actor ID. Each value holds the actor's
            "class_id", "driver_id", and "local_scheduler_id" as hex
            strings, its "num_gpus" as an int, and a boolean "removed".
    """
    prefix_length = len("Actor:")
    actor_table = {}
    for redis_key in self.redis_client.keys("Actor:*"):
        fields = self.redis_client.hgetall(redis_key)
        raw_actor_id = redis_key[prefix_length:]
        assert len(raw_actor_id) == 20
        entry = {
            "class_id": binary_to_hex(fields[b"class_id"]),
            "driver_id": binary_to_hex(fields[b"driver_id"]),
            "local_scheduler_id":
                binary_to_hex(fields[b"local_scheduler_id"]),
            "num_gpus": int(fields[b"num_gpus"]),
            "removed": decode(fields[b"removed"]) == "True"}
        actor_table[binary_to_hex(raw_actor_id)] = entry
    return actor_table
def _job_length(self):
event_log_sets = self.redis_client.keys("event_log*")
overall_smallest = sys.maxsize
+140
View File
@@ -4,7 +4,9 @@ from __future__ import print_function
import binascii
import collections
import json
import numpy as np
import redis
import sys
import ray.local_scheduler
@@ -65,3 +67,141 @@ FunctionProperties = collections.namedtuple("FunctionProperties",
"num_gpus",
"max_calls"])
"""FunctionProperties: A named tuple storing remote functions information."""
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler,
                            redis_client):
    """Try to reserve GPUs on one local scheduler on behalf of an actor.

    The per-driver GPU usage of the local scheduler is kept as a JSON
    dictionary in the "gpus_in_use" field of the Redis hash keyed by the
    local scheduler's "DBClientID". The update is done optimistically with
    WATCH/MULTI/EXEC and is retried whenever the watched key changes.

    Args:
        num_gpus: The number of GPUs to acquire. Must be nonzero.
        driver_id: The ID of the driver responsible for creating the actor.
        local_scheduler: Information about the local scheduler.
        redis_client: The redis client to use for interacting with Redis.

    Returns:
        True if the GPUs were successfully reserved and false otherwise.
    """
    assert num_gpus != 0
    scheduler_key = local_scheduler["DBClientID"]
    total_gpus = int(local_scheduler["NumGPUs"])

    with redis_client.pipeline() as pipe:
        while True:
            reserved = False
            try:
                # The transaction below is abandoned if this key is modified
                # before EXEC runs.
                pipe.watch(scheduler_key)

                # Read the current usage, keyed by hex driver ID.
                raw_usage = redis_client.hget(scheduler_key, "gpus_in_use")
                if raw_usage is None:
                    usage = dict()
                else:
                    usage = json.loads(raw_usage.decode("ascii"))
                gpus_taken = sum(usage.values())
                assert gpus_taken <= total_gpus

                pipe.multi()

                if total_gpus - gpus_taken >= num_gpus:
                    # Hex keys keep the dictionary JSON serializable.
                    hex_driver_id = binary_to_hex(driver_id)
                    usage[hex_driver_id] = (usage.get(hex_driver_id, 0) +
                                            num_gpus)
                    pipe.hset(scheduler_key, "gpus_in_use",
                              json.dumps(usage))
                    reserved = True

                pipe.execute()
            except redis.WatchError:
                # Someone else touched the key between WATCH and EXEC, so
                # start over.
                continue
            # No WatchError means the transaction went through atomically.
            return reserved
def select_local_scheduler(driver_id, local_schedulers, num_gpus,
                           redis_client):
    """Pick a local scheduler that can host a new actor.

    Candidates are visited in random order. A candidate is skipped if it
    reports fewer than one CPU or fewer GPUs than requested. When GPUs are
    requested, they are reserved on the candidate before it is chosen.

    Args:
        driver_id: The ID of the driver who the actor is for.
        local_schedulers: A list of dictionaries of information about the
            local schedulers.
        num_gpus (int): The number of GPUs that must be reserved for this
            actor.
        redis_client: The Redis client to use for interacting with Redis.

    Returns:
        The ID of the local scheduler that has been chosen.

    Raises:
        Exception: An exception is raised if no local scheduler can be found
            with sufficient resources.
    """
    shuffled = np.random.permutation(local_schedulers)
    for candidate in shuffled:
        lacks_cpus = candidate["NumCPUs"] < 1
        if lacks_cpus or candidate["NumGPUs"] < num_gpus:
            continue
        # With no GPU requirement any eligible candidate will do; otherwise
        # the reservation has to succeed first.
        if num_gpus == 0 or attempt_to_reserve_gpus(num_gpus, driver_id,
                                                    candidate, redis_client):
            return hex_to_binary(candidate["DBClientID"])
    raise Exception("Could not find a node with enough GPUs or other "
                    "resources to create this actor. The local scheduler "
                    "information is {}.".format(shuffled))
def publish_actor_creation(actor_id, driver_id, local_scheduler_id,
                           redis_client):
    """Publish a notification that an actor should be created.

    The message is the concatenation of the three IDs, published on the
    "actor_notifications" channel.

    This broadcast will be received by all of the local schedulers. The local
    scheduler whose ID is being broadcast will create the actor. Any other
    local schedulers that have already created the actor will kill it. All
    local schedulers will update their internal data structures to redirect
    tasks for this actor to the new local scheduler.

    Args:
        actor_id: The ID of the actor involved.
        driver_id: The ID of the driver responsible for the actor.
        local_scheduler_id: The ID of the local scheduler that is supposed to
            create the actor.
        redis_client: The client used to interact with Redis.
    """
    # Ideally this message would be a flatbuffer object, but in Python 2.7
    # builder.CreateString fails on byte strings that contain characters
    # outside range(128), so the raw IDs are concatenated instead.
    notification = actor_id + driver_id + local_scheduler_id
    redis_client.publish("actor_notifications", notification)
+262 -251
View File
@@ -473,15 +473,14 @@ class Worker(object):
assert final_results[i][0] == object_ids[i].id()
return [result[1][0] for result in final_results]
def submit_task(self, function_id, func_name, args, actor_id=None):
def submit_task(self, function_id, args, actor_id=None):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with name
func_name with arguments args. Retrieve object IDs for the outputs of
Tell the scheduler to schedule the execution of the function with ID
function_id with arguments args. Retrieve object IDs for the outputs of
the function from the scheduler and immediately return them.
Args:
func_name (str): The name of the function to be executed.
args (List[Any]): The arguments to pass into the function.
Arguments can be object IDs or they can be values. If they are
values, they must be serializable objecs.
@@ -513,7 +512,8 @@ class Worker(object):
function_properties.num_return_vals,
self.current_task_id,
self.task_index,
actor_id, self.actor_counters[actor_id],
actor_id,
self.actor_counters[actor_id],
[function_properties.num_cpus, function_properties.num_gpus])
# Increment the worker's task index to track how many tasks have
# been submitted by the current task so far.
@@ -582,6 +582,260 @@ class Worker(object):
"data": data})
self.redis_client.rpush("ErrorKeys", error_key)
def _wait_for_function(self, function_id, driver_id, timeout=10):
    """Block until the function to be executed is present on this worker.

    This polls until the relevant function has been registered. If this
    worker is an actor (its actor ID is not NIL_ACTOR_ID), it instead polls
    until the actor has been defined. The loop never gives up; once more
    than timeout seconds have passed, a single warning is pushed to the
    driver and polling continues.

    Args:
        function_id: The ID of the function that we want to execute.
        driver_id: The ID of the driver to push the error message to if this
            times out.
        timeout: The number of seconds to wait before pushing the warning.
    """
    warning_message = ("This worker was asked to execute a "
                       "function that it does not have "
                       "registered. You may have to restart "
                       "Ray.")
    began_waiting = time.time()
    # Only send the warning once.
    already_warned = False
    while True:
        with self.lock:
            if self.actor_id == NIL_ACTOR_ID:
                ready = function_id.id() in self.functions[driver_id]
            else:
                ready = self.actor_id in self.actors
            if ready:
                return
            if (not already_warned and
                    time.time() - began_waiting > timeout):
                self.push_error_to_driver(driver_id, "wait_for_function",
                                          warning_message)
                already_warned = True
        time.sleep(0.001)
def _get_arguments_for_execution(self, function_name, serialized_args):
    """Resolve the arguments of a remote function before it is executed.

    Arguments that are ObjectIDs are replaced by the values fetched with
    self.get_object. All other arguments are passed through unchanged.

    Args:
        function_name (str): The name of the remote function whose
            arguments are being retrieved.
        serialized_args (List): The arguments to the function. Each one is
            either an ObjectID or a value that was passed directly.

    Returns:
        A list with the fetched values in place of the ObjectIDs and the
            remaining arguments as given.

    Raises:
        RayGetArgumentError: This exception is raised if a fetched argument
            is a RayTaskError, i.e. the task that created it failed.
    """
    resolved = []
    for position, serialized in enumerate(serialized_args):
        if not isinstance(serialized, ray.local_scheduler.ObjectID):
            # Passed by value, so use it directly.
            resolved.append(serialized)
            continue
        # Passed by reference, so fetch it from the local object store.
        value = self.get_object([serialized])[0]
        if isinstance(value, RayTaskError):
            # The producing task failed; propagate its error message.
            raise RayGetArgumentError(function_name, position, serialized,
                                      value)
        resolved.append(value)
    return resolved
def _store_outputs_in_objstore(self, objectids, outputs):
"""Store the outputs of a remote function in the local object store.
This stores the values that were returned by a remote function in the
local object store. If any of the return values are object IDs, then
these object IDs are aliased with the object IDs that the scheduler
assigned for the return values. This is called by the worker that
executes the remote function.
Note:
The arguments objectids and outputs should have the same length.
Args:
objectids (List[ObjectID]): The object IDs that were assigned to
the outputs of the remote function call.
outputs (Tuple): The value returned by the remote function. If the
remote function was supposed to only return one value, then its
output was wrapped in a tuple with one element prior to being
passed into this function.
"""
for i in range(len(objectids)):
self.put_object(objectids[i], outputs[i])
def _process_task(self, task):
    """Execute a task assigned to this worker.

    The task's arguments are fetched, the function (or actor method) is run,
    and the outputs are stored in the local object store. If any of these
    steps raises, a RayTaskError is stored for every return object ID
    instead (these will be retrieved by calls to get or by subsequent tasks
    that use the outputs of this task) and the error is pushed to the
    driver that the task belongs to.

    Args:
        task: The task to execute.
    """
    try:
        # Remember which driver the task belongs to so that errors are
        # propagated to the correct driver.
        self.task_driver_id = task.driver_id()
        self.current_task_id = task.task_id()
        self.current_function_id = task.function_id().id()
        self.task_index = 0
        self.put_index = 0

        function_id = task.function_id()
        args = task.arguments()
        return_object_ids = task.returns()
        driver_functions = self.functions[self.task_driver_id.id()]
        function_name, function_executor = driver_functions[function_id.id()]

        # Get task arguments from the object store.
        with log_span("ray:task:get_arguments", worker=self):
            arguments = self._get_arguments_for_execution(function_name,
                                                          args)

        # Execute the task.
        with log_span("ray:task:execute", worker=self):
            if task.actor_id().id() == NIL_ACTOR_ID:
                outputs = function_executor.executor(arguments)
            else:
                actor = self.actors[task.actor_id().id()]
                outputs = function_executor(actor, *arguments)

        # Store the outputs in the local object store.
        with log_span("ray:task:store_outputs", worker=self):
            if len(return_object_ids) == 1:
                outputs = (outputs,)
            self._store_outputs_in_objstore(return_object_ids, outputs)
    except Exception as e:
        # Which of the local variables "arguments" and "outputs" exist tells
        # us how far the task got, and the error message is formatted
        # accordingly.
        got_arguments = "arguments" in locals()
        got_outputs = "outputs" in locals()
        if not got_arguments:
            # The error occurred before the task execution.
            if isinstance(e, (RayGetError, RayGetArgumentError)):
                # In this case, getting the task arguments failed.
                traceback_str = None
            else:
                traceback_str = traceback.format_exc()
        elif got_outputs:
            # The error occurred after the task executed.
            traceback_str = format_error_message(traceback.format_exc())
        elif task.actor_id().id() == NIL_ACTOR_ID:
            # The error occurred during the task execution.
            traceback_str = format_error_message(traceback.format_exc(),
                                                 task_exception=True)
        else:
            # The error occurred during the execution of an actor task.
            traceback_str = format_error_message(traceback.format_exc())

        failure_object = RayTaskError(function_name, e, traceback_str)
        failure_objects = [failure_object] * len(return_object_ids)
        self._store_outputs_in_objstore(return_object_ids, failure_objects)
        # Log the error message.
        error_data = {"function_id": function_id.id(),
                      "function_name": function_name}
        self.push_error_to_driver(self.task_driver_id.id(), "task",
                                  str(failure_object), data=error_data)
def _wait_for_and_process_task(self, task):
    """Wait for a task to be ready and process the task.

    This waits until the task's function is registered on this worker,
    executes the task while holding self.lock, flushes the log events, and
    increments the function's execution counter. When the counter equals
    the function's max_calls, this worker disconnects from its local
    scheduler and the process exits.

    Args:
        task: The task to execute.
    """
    function_id = task.function_id()
    driver_key = task.driver_id().id()
    function_key = function_id.id()

    # Wait until the function to be executed has actually been registered
    # on this worker. We will push warnings to the user if we spend too
    # long in this loop.
    with log_span("ray:wait_for_function", worker=self):
        self._wait_for_function(function_id, driver_key)

    # Execute the task.
    # TODO(rkn): Consider acquiring this lock with a timeout and pushing a
    # warning to the user if we are waiting too long to acquire the lock
    # because that may indicate that the system is hanging, and it'd be
    # good to know where the system is hanging.
    log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self)
    with self.lock:
        log(event_type="ray:acquire_lock", kind=LOG_SPAN_END,
            worker=self)

        function_name, _ = self.functions[driver_key][function_key]
        contents = {"function_name": function_name,
                    "task_id": task.task_id().hex(),
                    "worker_id": binary_to_hex(self.worker_id)}
        with log_span("ray:task", contents=contents, worker=self):
            self._process_task(task)

    # Push all of the log events to the global state store.
    flush_log()

    # Increase the task execution counter.
    self.num_task_executions[driver_key][function_key] += 1

    reached_max_executions = (
        self.num_task_executions[driver_key][function_key] ==
        self.function_properties[driver_key][function_key].max_calls)
    if reached_max_executions:
        # Disconnect this worker's own client rather than reaching for the
        # module-level global worker, so the method acts only on self.
        self.local_scheduler_client.disconnect()
        os._exit(0)
def _get_next_task_from_local_scheduler(self):
    """Get the next task from the local scheduler.

    The call is wrapped in a "ray:get_task" log span.

    Returns:
        A task from the local scheduler.
    """
    with log_span("ray:get_task", worker=self):
        return self.local_scheduler_client.get_task()
def main_loop(self):
    """The main loop a worker runs to receive and execute tasks.

    A SIGTERM handler is installed that cleans up this worker and exits.
    The loop then repeatedly fetches a task from the local scheduler and
    processes it; it does not return.
    """

    def handle_sigterm(signum, frame):
        cleanup(worker=self)
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    check_main_thread()
    while True:
        next_task = self._get_next_task_from_local_scheduler()
        self._wait_for_and_process_task(next_task)
def get_gpu_ids():
"""Get the IDs of the GPU that are available to the worker.
@@ -1731,45 +1985,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
return ready_ids, remaining_ids
def wait_for_function(function_id, driver_id, timeout=10,
                      worker=global_worker):
    """Block until the function to be executed is present on the worker.

    This polls until the relevant function has been registered. If the
    worker is an actor (its actor ID is not NIL_ACTOR_ID), it instead polls
    until the actor has been defined. The loop never gives up; once more
    than timeout seconds have passed, a single warning is pushed to the
    driver and polling continues.

    Args:
        function_id: The ID of the function that we want to execute.
        driver_id: The ID of the driver to push the error message to if this
            times out.
        timeout: The number of seconds to wait before pushing the warning.
        worker: The worker to check.
    """
    warning_message = ("This worker was asked to execute a "
                       "function that it does not have "
                       "registered. You may have to restart Ray.")
    began_waiting = time.time()
    # Only send the warning once.
    already_warned = False
    while True:
        with worker.lock:
            if worker.actor_id == NIL_ACTOR_ID:
                ready = function_id.id() in worker.functions[driver_id]
            else:
                ready = worker.actor_id in worker.actors
            if ready:
                return
            if (not already_warned and
                    time.time() - began_waiting > timeout):
                worker.push_error_to_driver(driver_id, "wait_for_function",
                                            warning_message)
                already_warned = True
        time.sleep(0.001)
def format_error_message(exception_message, task_exception=False):
"""Improve the formatting of an exception thrown by a remote function.
@@ -1792,145 +2007,7 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def main_loop(worker=global_worker):
    """The main loop a worker runs to receive and execute tasks.

    A SIGTERM handler is installed that cleans up the worker and exits. The
    loop then repeatedly fetches a task from the local scheduler, waits for
    its function to be registered, and executes it while holding the worker
    lock. When a function's execution count equals its max_calls, the worker
    disconnects from the local scheduler and the process exits.

    Args:
        worker: The worker that runs the loop.
    """

    def handle_sigterm(signum, frame):
        cleanup(worker=worker)
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    def process_task(task):
        """Execute a task assigned to this worker.

        The task's arguments are fetched, the function (or actor method) is
        run, and the outputs are stored in the local object store. If any of
        these steps raises, a RayTaskError is stored for every return object
        ID instead (these will be retrieved by calls to get or by subsequent
        tasks that use the outputs of this task) and the error is pushed to
        the driver that the task belongs to.
        """
        try:
            # Remember which driver the task belongs to so that errors are
            # propagated to the correct driver.
            worker.task_driver_id = task.driver_id()
            worker.current_task_id = task.task_id()
            worker.current_function_id = task.function_id().id()
            worker.task_index = 0
            worker.put_index = 0

            function_id = task.function_id()
            args = task.arguments()
            return_object_ids = task.returns()
            driver_functions = worker.functions[worker.task_driver_id.id()]
            function_name, function_executor = (
                driver_functions[function_id.id()])

            # Get task arguments from the object store.
            with log_span("ray:task:get_arguments", worker=worker):
                arguments = get_arguments_for_execution(function_name, args,
                                                        worker)

            # Execute the task.
            with log_span("ray:task:execute", worker=worker):
                if task.actor_id().id() == NIL_ACTOR_ID:
                    outputs = function_executor.executor(arguments)
                else:
                    actor = worker.actors[task.actor_id().id()]
                    outputs = function_executor(actor, *arguments)

            # Store the outputs in the local object store.
            with log_span("ray:task:store_outputs", worker=worker):
                if len(return_object_ids) == 1:
                    outputs = (outputs,)
                store_outputs_in_objstore(return_object_ids, outputs, worker)
        except Exception as e:
            # Which of the local variables "arguments" and "outputs" exist
            # tells us how far the task got, and the error message is
            # formatted accordingly.
            got_arguments = "arguments" in locals()
            got_outputs = "outputs" in locals()
            if not got_arguments:
                # The error occurred before the task execution.
                if isinstance(e, (RayGetError, RayGetArgumentError)):
                    # In this case, getting the task arguments failed.
                    traceback_str = None
                else:
                    traceback_str = traceback.format_exc()
            elif got_outputs:
                # The error occurred after the task executed.
                traceback_str = format_error_message(traceback.format_exc())
            elif task.actor_id().id() == NIL_ACTOR_ID:
                # The error occurred during the task execution.
                traceback_str = format_error_message(traceback.format_exc(),
                                                     task_exception=True)
            else:
                # The error occurred during the execution of an actor task.
                traceback_str = format_error_message(traceback.format_exc())

            failure_object = RayTaskError(function_name, e, traceback_str)
            failure_objects = [failure_object] * len(return_object_ids)
            store_outputs_in_objstore(return_object_ids, failure_objects,
                                      worker)
            # Log the error message.
            error_data = {"function_id": function_id.id(),
                          "function_name": function_name}
            worker.push_error_to_driver(worker.task_driver_id.id(), "task",
                                        str(failure_object), data=error_data)

    check_main_thread()
    while True:
        with log_span("ray:get_task", worker=worker):
            task = worker.local_scheduler_client.get_task()

        function_id = task.function_id()
        # Wait until the function to be executed has actually been registered
        # on this worker. We will push warnings to the user if we spend too
        # long in this loop.
        with log_span("ray:wait_for_function", worker=worker):
            wait_for_function(function_id, task.driver_id().id(),
                              worker=worker)

        # Execute the task.
        # TODO(rkn): Consider acquiring this lock with a timeout and pushing a
        # warning to the user if we are waiting too long to acquire the lock
        # because that may indicate that the system is hanging, and it'd be
        # good to know where the system is hanging.
        log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker)
        with worker.lock:
            log(event_type="ray:acquire_lock", kind=LOG_SPAN_END,
                worker=worker)

            registered = worker.functions[task.driver_id().id()]
            function_name, _ = registered[function_id.id()]
            contents = {"function_name": function_name,
                        "task_id": task.task_id().hex(),
                        "worker_id": binary_to_hex(worker.worker_id)}
            with log_span("ray:task", contents=contents, worker=worker):
                process_task(task)

        # Push all of the log events to the global state store.
        flush_log()

        # Increase the task execution counter.
        execution_counts = worker.num_task_executions[task.driver_id().id()]
        execution_counts[function_id.id()] += 1

        properties = worker.function_properties[task.driver_id().id()]
        if (execution_counts[function_id.id()] ==
                properties[function_id.id()].max_calls):
            ray.worker.global_worker.local_scheduler_client.disconnect()
            os._exit(0)
def _submit_task(function_id, func_name, args, worker=global_worker):
def _submit_task(function_id, args, worker=global_worker):
"""This is a wrapper around worker.submit_task.
We use this wrapper so that in the remote decorator, we can call
@@ -1938,7 +2015,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker):
attempt to serialize remote functions, we don't attempt to serialize the
worker object, which cannot be serialized.
"""
return worker.submit_task(function_id, func_name, args)
return worker.submit_task(function_id, args)
def _mode(worker=global_worker):
@@ -2081,7 +2158,7 @@ def remote(*args, **kwargs):
# immutable remote objects.
result = func(*copy.deepcopy(args))
return result
objectids = _submit_task(function_id, func_name, args)
objectids = _submit_task(function_id, args)
if len(objectids) == 1:
return objectids[0]
elif len(objectids) > 1:
@@ -2157,69 +2234,3 @@ def remote(*args, **kwargs):
assert "function_id" not in kwargs
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
max_calls)
def get_arguments_for_execution(function_name, serialized_args,
                                worker=global_worker):
    """Resolve the arguments of a remote function before it is executed.

    Arguments that are ObjectIDs are replaced by the values fetched with
    worker.get_object. All other arguments are passed through unchanged.

    Args:
        function_name (str): The name of the remote function whose arguments
            are being retrieved.
        serialized_args (List): The arguments to the function. Each one is
            either an ObjectID or a value that was passed directly.
        worker: The worker used to fetch objects.

    Returns:
        A list with the fetched values in place of the ObjectIDs and the
            remaining arguments as given.

    Raises:
        RayGetArgumentError: This exception is raised if a fetched argument
            is a RayTaskError, i.e. the task that created it failed.
    """
    resolved = []
    for position, serialized in enumerate(serialized_args):
        if not isinstance(serialized, ray.local_scheduler.ObjectID):
            # Passed by value, so use it directly.
            resolved.append(serialized)
            continue
        # Passed by reference, so fetch it from the local object store.
        value = worker.get_object([serialized])[0]
        if isinstance(value, RayTaskError):
            # The producing task failed; propagate its error message.
            raise RayGetArgumentError(function_name, position, serialized,
                                      value)
        resolved.append(value)
    return resolved
def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
    """Store the outputs of a remote function in the local object store.

    Each output is stored with worker.put_object under the object ID at the
    same position. This is called by the worker that executes the remote
    function.

    Note:
        The arguments objectids and outputs should have the same length.

    Args:
        objectids (List[ObjectID]): The object IDs that were assigned to the
            outputs of the remote function call.
        outputs (Tuple): The value returned by the remote function. If the
            remote function was supposed to only return one value, then its
            output was wrapped in a tuple with one element prior to being
            passed into this function.
        worker: The worker used to store the objects.
    """
    for position, object_id in enumerate(objectids):
        worker.put_object(object_id, outputs[position])
+29 -17
View File
@@ -30,6 +30,31 @@ def random_string():
return np.random.bytes(20)
def create_redis_client(redis_address):
    """Create a Redis client for the given address.

    Args:
        redis_address: The address of the Redis server as a string of the
            form "<ip address>:<port>".

    Returns:
        A redis.StrictRedis client for that host and port.
    """
    host, port = redis_address.split(":")
    # For this command to work, some other client (on the same machine
    # as Redis) must have run "CONFIG SET protected-mode no".
    client = redis.StrictRedis(host=host, port=int(port))
    return client
def push_error_to_all_drivers(redis_client, message):
    """Push an error message to all drivers.

    The error is stored in a Redis hash with type "worker_crash" under a
    freshly generated error key, and that key is appended to the
    "ErrorKeys" list.

    Args:
        redis_client: The redis client to use.
        message: The error message to push.
    """
    DRIVER_ID_LENGTH = 20
    # We use a driver ID of all zeros to push an error message to all
    # drivers.
    broadcast_driver_id = b"\x00" * DRIVER_ID_LENGTH
    error_key = b"".join([b"Error:", broadcast_driver_id, b":",
                          random_string()])
    # Record the error, then announce its key.
    error_fields = {"type": "worker_crash",
                    "message": message}
    redis_client.hmset(error_key, error_fields)
    redis_client.rpush("ErrorKeys", error_key)
if __name__ == "__main__":
args = parser.parse_args()
info = {"node_ip_address": args.node_ip_address,
@@ -57,25 +82,12 @@ if __name__ == "__main__":
# task) should be caught and handled inside of the call to
# main_loop. If an exception is thrown here, then that means that
# there is some error that we didn't anticipate.
ray.worker.main_loop()
ray.worker.global_worker.main_loop()
except Exception as e:
traceback_str = traceback.format_exc() + error_explanation
DRIVER_ID_LENGTH = 20
# We use a driver ID of all zeros to push an error message to all
# drivers.
driver_id = DRIVER_ID_LENGTH * b"\x00"
error_key = b"Error:" + driver_id + b":" + random_string()
redis_ip_address, redis_port = args.redis_address.split(":")
# For this command to work, some other client (on the same machine
# as Redis) must have run "CONFIG SET protected-mode no".
redis_client = redis.StrictRedis(host=redis_ip_address,
port=int(redis_port))
redis_client.hmset(error_key, {"type": "worker_crash",
"message": traceback_str,
"note": ("This error is unexpected "
"and should not have "
"happened.")})
redis_client.rpush("ErrorKeys", error_key)
# Create a Redis client.
redis_client = create_redis_client(args.redis_address)
push_error_to_all_drivers(redis_client, traceback_str)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to