mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 06:36:36 +08:00
Move worker methods into Worker class and expose more TaskSpec fields to Python. (#796)
* Move worker methods inside worker class. Move some helper methods from actor.py into utils.py and state.py. * Add more methods exposing task spec fields to Python. * Fix linting. * Fix error. * Remove unused code in default worker.
This commit is contained in:
committed by
Philipp Moritz
parent
52a27be364
commit
8c8258de20
+8
-134
@@ -6,15 +6,13 @@ import cloudpickle as pickle
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import numpy as np
|
||||
import redis
|
||||
import traceback
|
||||
|
||||
import ray.local_scheduler
|
||||
import ray.signature as signature
|
||||
import ray.worker
|
||||
from ray.utils import (FunctionProperties, binary_to_hex, hex_to_binary,
|
||||
random_string)
|
||||
from ray.utils import (FunctionProperties, random_string,
|
||||
select_local_scheduler)
|
||||
|
||||
|
||||
def random_actor_id():
|
||||
@@ -102,117 +100,6 @@ def fetch_and_register_actor(actor_class_key, worker):
|
||||
# for the actor.
|
||||
|
||||
|
||||
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler, worker):
|
||||
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
|
||||
|
||||
Args:
|
||||
num_gpus: The number of GPUs to acquire.
|
||||
driver_id: The ID of the driver responsible for creating the actor.
|
||||
local_scheduler: Information about the local scheduler.
|
||||
|
||||
Returns:
|
||||
True if the GPUs were successfully reserved and false otherwise.
|
||||
"""
|
||||
assert num_gpus != 0
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
|
||||
|
||||
success = False
|
||||
|
||||
# Attempt to acquire GPU IDs atomically.
|
||||
with worker.redis_client.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction below (the
|
||||
# multi/exec block), then the transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
|
||||
# Figure out which GPUs are currently in use.
|
||||
result = worker.redis_client.hget(local_scheduler_id,
|
||||
"gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(
|
||||
result.decode("ascii"))
|
||||
num_gpus_in_use = 0
|
||||
for key in gpus_in_use:
|
||||
num_gpus_in_use += gpus_in_use[key]
|
||||
assert num_gpus_in_use <= local_scheduler_total_gpus
|
||||
|
||||
pipe.multi()
|
||||
|
||||
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
|
||||
# There are enough available GPUs, so try to reserve some.
|
||||
# We use the hex driver ID in hex as a dictionary key so
|
||||
# that the dictionary is JSON serializable.
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex not in gpus_in_use:
|
||||
gpus_in_use[driver_id_hex] = 0
|
||||
gpus_in_use[driver_id_hex] += num_gpus
|
||||
|
||||
# Stick the updated GPU IDs back in Redis
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use",
|
||||
json.dumps(gpus_in_use))
|
||||
success = True
|
||||
|
||||
pipe.execute()
|
||||
# If a WatchError is not raised, then the operations should
|
||||
# have gone through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key between the
|
||||
# time we started WATCHing it and the pipeline's execution. We
|
||||
# should just retry.
|
||||
success = False
|
||||
continue
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def select_local_scheduler(local_schedulers, num_gpus, worker):
|
||||
"""Select a local scheduler to assign this actor to.
|
||||
|
||||
Args:
|
||||
local_schedulers: A list of dictionaries of information about the local
|
||||
schedulers.
|
||||
num_gpus (int): The number of GPUs that must be reserved for this
|
||||
actor.
|
||||
|
||||
Returns:
|
||||
The ID of the local scheduler that has been chosen.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if no local scheduler can be found
|
||||
with sufficient resources.
|
||||
"""
|
||||
driver_id = worker.task_driver_id.id()
|
||||
|
||||
local_scheduler_id = None
|
||||
# Loop through all of the local schedulers in a random order.
|
||||
local_schedulers = np.random.permutation(local_schedulers)
|
||||
for local_scheduler in local_schedulers:
|
||||
if local_scheduler["NumCPUs"] < 1:
|
||||
continue
|
||||
if local_scheduler["NumGPUs"] < num_gpus:
|
||||
continue
|
||||
if num_gpus == 0:
|
||||
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
|
||||
break
|
||||
else:
|
||||
# Try to reserve enough GPUs on this local scheduler.
|
||||
success = attempt_to_reserve_gpus(num_gpus, driver_id,
|
||||
local_scheduler, worker)
|
||||
if success:
|
||||
local_scheduler_id = hex_to_binary(
|
||||
local_scheduler["DBClientID"])
|
||||
break
|
||||
|
||||
if local_scheduler_id is None:
|
||||
raise Exception("Could not find a node with enough GPUs or other "
|
||||
"resources to create this actor. The local scheduler "
|
||||
"information is {}.".format(local_schedulers))
|
||||
|
||||
return local_scheduler_id
|
||||
|
||||
|
||||
def export_actor_class(class_id, Class, actor_method_names, worker):
|
||||
if worker.mode is None:
|
||||
raise NotImplemented("TODO(pcm): Cache actors")
|
||||
@@ -255,17 +142,10 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
|
||||
num_gpus=0,
|
||||
max_calls=0))
|
||||
|
||||
# Get a list of the local schedulers from the client table.
|
||||
client_table = ray.global_state.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, clients in client_table.items():
|
||||
for client in clients:
|
||||
if (client["ClientType"] == "local_scheduler" and
|
||||
not client["Deleted"]):
|
||||
local_schedulers.append(client)
|
||||
# Select a local scheduler for the actor.
|
||||
local_scheduler_id = select_local_scheduler(local_schedulers, num_gpus,
|
||||
worker)
|
||||
local_scheduler_id = select_local_scheduler(
|
||||
worker.task_driver_id.id(), ray.global_state.local_schedulers(),
|
||||
num_gpus, worker.redis_client)
|
||||
assert local_scheduler_id is not None
|
||||
|
||||
# We must put the actor information in Redis before publishing the actor
|
||||
@@ -274,17 +154,12 @@ def export_actor(actor_id, class_id, actor_method_names, num_cpus, num_gpus,
|
||||
worker.redis_client.hmset(key, {"class_id": class_id,
|
||||
"num_gpus": num_gpus})
|
||||
|
||||
# Really we should encode this message as a flatbuffer object. However,
|
||||
# we're having trouble getting that to work. It almost works, but in Python
|
||||
# 2.7, builder.CreateString fails on byte strings that contain characters
|
||||
# outside range(128).
|
||||
|
||||
# TODO(rkn): There is actually no guarantee that the local scheduler that
|
||||
# we are publishing to has already subscribed to the actor_notifications
|
||||
# channel. Therefore, this message may be missed and the workload will
|
||||
# hang. This is a bug.
|
||||
worker.redis_client.publish("actor_notifications",
|
||||
actor_id.id() + driver_id + local_scheduler_id)
|
||||
ray.utils.publish_actor_creation(actor_id.id(), driver_id,
|
||||
local_scheduler_id, worker.redis_client)
|
||||
|
||||
|
||||
def actor(*args, **kwargs):
|
||||
@@ -319,8 +194,7 @@ def make_actor(cls, num_cpus, num_gpus):
|
||||
args = signature.extend_args(function_signature, args, kwargs)
|
||||
|
||||
function_id = get_actor_method_function_id(attr)
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, "",
|
||||
args,
|
||||
object_ids = ray.worker.global_worker.submit_task(function_id, args,
|
||||
actor_id=actor_id)
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
|
||||
@@ -642,6 +642,21 @@ class GlobalState(object):
|
||||
all_times.append(data["store_outputs_end"])
|
||||
return all_times
|
||||
|
||||
def local_schedulers(self):
|
||||
"""Get a list of live local schedulers.
|
||||
|
||||
Returns:
|
||||
A list of the live local schedulers.
|
||||
"""
|
||||
clients = self.client_table()
|
||||
local_schedulers = []
|
||||
for ip_address, client_list in clients.items():
|
||||
for client in client_list:
|
||||
if (client["ClientType"] == "local_scheduler" and
|
||||
not client["Deleted"]):
|
||||
local_schedulers.append(client)
|
||||
return local_schedulers
|
||||
|
||||
def workers(self):
|
||||
"""Get a dictionary mapping worker ID to worker information."""
|
||||
worker_keys = self.redis_client.keys("Worker*")
|
||||
@@ -666,6 +681,22 @@ class GlobalState(object):
|
||||
}
|
||||
return workers_data
|
||||
|
||||
def actors(self):
|
||||
actor_keys = self.redis_client.keys("Actor:*")
|
||||
actor_info = dict()
|
||||
for key in actor_keys:
|
||||
info = self.redis_client.hgetall(key)
|
||||
actor_id = key[len("Actor:"):]
|
||||
assert len(actor_id) == 20
|
||||
actor_info[binary_to_hex(actor_id)] = {
|
||||
"class_id": binary_to_hex(info[b"class_id"]),
|
||||
"driver_id": binary_to_hex(info[b"driver_id"]),
|
||||
"local_scheduler_id":
|
||||
binary_to_hex(info[b"local_scheduler_id"]),
|
||||
"num_gpus": int(info[b"num_gpus"]),
|
||||
"removed": decode(info[b"removed"]) == "True"}
|
||||
return actor_info
|
||||
|
||||
def _job_length(self):
|
||||
event_log_sets = self.redis_client.keys("event_log*")
|
||||
overall_smallest = sys.maxsize
|
||||
|
||||
@@ -4,7 +4,9 @@ from __future__ import print_function
|
||||
|
||||
import binascii
|
||||
import collections
|
||||
import json
|
||||
import numpy as np
|
||||
import redis
|
||||
import sys
|
||||
|
||||
import ray.local_scheduler
|
||||
@@ -65,3 +67,141 @@ FunctionProperties = collections.namedtuple("FunctionProperties",
|
||||
"num_gpus",
|
||||
"max_calls"])
|
||||
"""FunctionProperties: A named tuple storing remote functions information."""
|
||||
|
||||
|
||||
def attempt_to_reserve_gpus(num_gpus, driver_id, local_scheduler,
|
||||
redis_client):
|
||||
"""Attempt to acquire GPUs on a particular local scheduler for an actor.
|
||||
|
||||
Args:
|
||||
num_gpus: The number of GPUs to acquire.
|
||||
driver_id: The ID of the driver responsible for creating the actor.
|
||||
local_scheduler: Information about the local scheduler.
|
||||
redis_client: The redis client to use for interacting with Redis.
|
||||
|
||||
Returns:
|
||||
True if the GPUs were successfully reserved and false otherwise.
|
||||
"""
|
||||
assert num_gpus != 0
|
||||
local_scheduler_id = local_scheduler["DBClientID"]
|
||||
local_scheduler_total_gpus = int(local_scheduler["NumGPUs"])
|
||||
|
||||
success = False
|
||||
|
||||
# Attempt to acquire GPU IDs atomically.
|
||||
with redis_client.pipeline() as pipe:
|
||||
while True:
|
||||
try:
|
||||
# If this key is changed before the transaction below (the
|
||||
# multi/exec block), then the transaction will not take place.
|
||||
pipe.watch(local_scheduler_id)
|
||||
|
||||
# Figure out which GPUs are currently in use.
|
||||
result = redis_client.hget(local_scheduler_id, "gpus_in_use")
|
||||
gpus_in_use = dict() if result is None else json.loads(
|
||||
result.decode("ascii"))
|
||||
num_gpus_in_use = 0
|
||||
for key in gpus_in_use:
|
||||
num_gpus_in_use += gpus_in_use[key]
|
||||
assert num_gpus_in_use <= local_scheduler_total_gpus
|
||||
|
||||
pipe.multi()
|
||||
|
||||
if local_scheduler_total_gpus - num_gpus_in_use >= num_gpus:
|
||||
# There are enough available GPUs, so try to reserve some.
|
||||
# We use the hex driver ID in hex as a dictionary key so
|
||||
# that the dictionary is JSON serializable.
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
if driver_id_hex not in gpus_in_use:
|
||||
gpus_in_use[driver_id_hex] = 0
|
||||
gpus_in_use[driver_id_hex] += num_gpus
|
||||
|
||||
# Stick the updated GPU IDs back in Redis
|
||||
pipe.hset(local_scheduler_id, "gpus_in_use",
|
||||
json.dumps(gpus_in_use))
|
||||
success = True
|
||||
|
||||
pipe.execute()
|
||||
# If a WatchError is not raised, then the operations should
|
||||
# have gone through atomically.
|
||||
break
|
||||
except redis.WatchError:
|
||||
# Another client must have changed the watched key between the
|
||||
# time we started WATCHing it and the pipeline's execution. We
|
||||
# should just retry.
|
||||
success = False
|
||||
continue
|
||||
|
||||
return success
|
||||
|
||||
|
||||
def select_local_scheduler(driver_id, local_schedulers, num_gpus,
|
||||
redis_client):
|
||||
"""Select a local scheduler to assign this actor to.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver who the actor is for.
|
||||
local_schedulers: A list of dictionaries of information about the local
|
||||
schedulers.
|
||||
num_gpus (int): The number of GPUs that must be reserved for this
|
||||
actor.
|
||||
redis_client: The Redis client to use for interacting with Redis.
|
||||
|
||||
Returns:
|
||||
The ID of the local scheduler that has been chosen.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if no local scheduler can be found
|
||||
with sufficient resources.
|
||||
"""
|
||||
local_scheduler_id = None
|
||||
# Loop through all of the local schedulers in a random order.
|
||||
local_schedulers = np.random.permutation(local_schedulers)
|
||||
for local_scheduler in local_schedulers:
|
||||
if local_scheduler["NumCPUs"] < 1:
|
||||
continue
|
||||
if local_scheduler["NumGPUs"] < num_gpus:
|
||||
continue
|
||||
if num_gpus == 0:
|
||||
local_scheduler_id = hex_to_binary(local_scheduler["DBClientID"])
|
||||
break
|
||||
else:
|
||||
# Try to reserve enough GPUs on this local scheduler.
|
||||
success = attempt_to_reserve_gpus(num_gpus, driver_id,
|
||||
local_scheduler, redis_client)
|
||||
if success:
|
||||
local_scheduler_id = hex_to_binary(
|
||||
local_scheduler["DBClientID"])
|
||||
break
|
||||
|
||||
if local_scheduler_id is None:
|
||||
raise Exception("Could not find a node with enough GPUs or other "
|
||||
"resources to create this actor. The local scheduler "
|
||||
"information is {}.".format(local_schedulers))
|
||||
|
||||
return local_scheduler_id
|
||||
|
||||
|
||||
def publish_actor_creation(actor_id, driver_id, local_scheduler_id,
|
||||
redis_client):
|
||||
"""Publish a notification that an actor should be created.
|
||||
|
||||
This broadcast will be received by all of the local schedulers. The local
|
||||
scheduler whose ID is being broadcast will create the actor. Any other
|
||||
local schedulers that have already created the actor will kill it. All
|
||||
local schedulers will update their internal data structures to redirect
|
||||
tasks for this actor to the new local scheduler.
|
||||
|
||||
Args:
|
||||
actor_id: The ID of the actor involved.
|
||||
driver_id: The ID of the driver responsible for the actor.
|
||||
local_scheduler_id: The ID of the local scheduler that is suposed to
|
||||
create the actor.
|
||||
redis_client: The client used to interact with Redis.
|
||||
"""
|
||||
# Really we should encode this message as a flatbuffer object. However,
|
||||
# we're having trouble getting that to work. It almost works, but in Python
|
||||
# 2.7, builder.CreateString fails on byte strings that contain characters
|
||||
# outside range(128).
|
||||
redis_client.publish("actor_notifications",
|
||||
actor_id + driver_id + local_scheduler_id)
|
||||
|
||||
+262
-251
@@ -473,15 +473,14 @@ class Worker(object):
|
||||
assert final_results[i][0] == object_ids[i].id()
|
||||
return [result[1][0] for result in final_results]
|
||||
|
||||
def submit_task(self, function_id, func_name, args, actor_id=None):
|
||||
def submit_task(self, function_id, args, actor_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with name
|
||||
func_name with arguments args. Retrieve object IDs for the outputs of
|
||||
Tell the scheduler to schedule the execution of the function with ID
|
||||
function_id with arguments args. Retrieve object IDs for the outputs of
|
||||
the function from the scheduler and immediately return them.
|
||||
|
||||
Args:
|
||||
func_name (str): The name of the function to be executed.
|
||||
args (List[Any]): The arguments to pass into the function.
|
||||
Arguments can be object IDs or they can be values. If they are
|
||||
values, they must be serializable objecs.
|
||||
@@ -513,7 +512,8 @@ class Worker(object):
|
||||
function_properties.num_return_vals,
|
||||
self.current_task_id,
|
||||
self.task_index,
|
||||
actor_id, self.actor_counters[actor_id],
|
||||
actor_id,
|
||||
self.actor_counters[actor_id],
|
||||
[function_properties.num_cpus, function_properties.num_gpus])
|
||||
# Increment the worker's task index to track how many tasks have
|
||||
# been submitted by the current task so far.
|
||||
@@ -582,6 +582,260 @@ class Worker(object):
|
||||
"data": data})
|
||||
self.redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
def _wait_for_function(self, function_id, driver_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
relevant function. If we spend too long in this loop, that may indicate
|
||||
a problem somewhere and we will push an error message to the user.
|
||||
|
||||
If this worker is an actor, then this will wait until the actor has
|
||||
been defined.
|
||||
|
||||
Args:
|
||||
is_actor (bool): True if this worker is an actor, and false
|
||||
otherwise.
|
||||
function_id (str): The ID of the function that we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
with self.lock:
|
||||
if (self.actor_id == NIL_ACTOR_ID and
|
||||
(function_id.id() in self.functions[driver_id])):
|
||||
break
|
||||
elif self.actor_id != NIL_ACTOR_ID and (self.actor_id in
|
||||
self.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart "
|
||||
"Ray.")
|
||||
if not warning_sent:
|
||||
self.push_error_to_driver(driver_id,
|
||||
"wait_for_function",
|
||||
warning_message)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
def _get_arguments_for_execution(self, function_name, serialized_args):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that
|
||||
were passed in as object IDs. Argumens that were passed by value are
|
||||
not changed. This is called by the worker that is executing the remote
|
||||
function.
|
||||
|
||||
Args:
|
||||
function_name (str): The name of the remote function whose
|
||||
arguments are being retrieved.
|
||||
serialized_args (List): The arguments to the function. These are
|
||||
either strings representing serialized objects passed by value
|
||||
or they are ObjectIDs.
|
||||
|
||||
Returns:
|
||||
The retrieved arguments in addition to the arguments that were
|
||||
passed by value.
|
||||
|
||||
Raises:
|
||||
RayGetArgumentError: This exception is raised if a task that
|
||||
created one of the arguments failed.
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, ray.local_scheduler.ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = self.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that
|
||||
# created this object failed, and we should propagate the
|
||||
# error message here.
|
||||
raise RayGetArgumentError(function_name, i, arg, argument)
|
||||
else:
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
def _store_outputs_in_objstore(self, objectids, outputs):
|
||||
"""Store the outputs of a remote function in the local object store.
|
||||
|
||||
This stores the values that were returned by a remote function in the
|
||||
local object store. If any of the return values are object IDs, then
|
||||
these object IDs are aliased with the object IDs that the scheduler
|
||||
assigned for the return values. This is called by the worker that
|
||||
executes the remote function.
|
||||
|
||||
Note:
|
||||
The arguments objectids and outputs should have the same length.
|
||||
|
||||
Args:
|
||||
objectids (List[ObjectID]): The object IDs that were assigned to
|
||||
the outputs of the remote function call.
|
||||
outputs (Tuple): The value returned by the remote function. If the
|
||||
remote function was supposed to only return one value, then its
|
||||
output was wrapped in a tuple with one element prior to being
|
||||
passed into this function.
|
||||
"""
|
||||
for i in range(len(objectids)):
|
||||
self.put_object(objectids[i], outputs[i])
|
||||
|
||||
def _process_task(self, task):
|
||||
"""Execute a task assigned to this worker.
|
||||
|
||||
This method deserializes a task from the scheduler, and attempts to
|
||||
execute the task. If the task succeeds, the outputs are stored in the
|
||||
local object store. If the task throws an exception, RayTaskError
|
||||
objects are stored in the object store to represent the failed task
|
||||
(these will be retrieved by calls to get or by subsequent tasks that
|
||||
use the outputs of this task).
|
||||
"""
|
||||
try:
|
||||
# The ID of the driver that this task belongs to. This is needed so
|
||||
# that if the task throws an exception, we propagate the error
|
||||
# message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_task_id = task.task_id()
|
||||
self.current_function_id = task.function_id().id()
|
||||
self.task_index = 0
|
||||
self.put_index = 0
|
||||
function_id = task.function_id()
|
||||
args = task.arguments()
|
||||
return_object_ids = task.returns()
|
||||
function_name, function_executor = (self.functions
|
||||
[self.task_driver_id.id()]
|
||||
[function_id.id()])
|
||||
|
||||
# Get task arguments from the object store.
|
||||
with log_span("ray:task:get_arguments", worker=self):
|
||||
arguments = self._get_arguments_for_execution(function_name,
|
||||
args)
|
||||
|
||||
# Execute the task.
|
||||
with log_span("ray:task:execute", worker=self):
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
outputs = function_executor.executor(arguments)
|
||||
else:
|
||||
outputs = function_executor(
|
||||
self.actors[task.actor_id().id()], *arguments)
|
||||
|
||||
# Store the outputs in the local object store.
|
||||
with log_span("ray:task:store_outputs", worker=self):
|
||||
if len(return_object_ids) == 1:
|
||||
outputs = (outputs,)
|
||||
self._store_outputs_in_objstore(return_object_ids, outputs)
|
||||
except Exception as e:
|
||||
# We determine whether the exception was caused by the call to
|
||||
# _get_arguments_for_execution or by the execution of the remote
|
||||
# function or by the call to _store_outputs_in_objstore. Depending
|
||||
# on which case occurred, we format the error message differently.
|
||||
# whether the variables "arguments" and "outputs" are defined.
|
||||
if "arguments" in locals() and "outputs" not in locals():
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
# The error occurred during the task execution.
|
||||
traceback_str = format_error_message(
|
||||
traceback.format_exc(), task_exception=True)
|
||||
else:
|
||||
# The error occurred during the execution of an actor task.
|
||||
traceback_str = format_error_message(
|
||||
traceback.format_exc())
|
||||
elif "arguments" in locals() and "outputs" in locals():
|
||||
# The error occurred after the task executed.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
else:
|
||||
# The error occurred before the task execution.
|
||||
if (isinstance(e, RayGetError) or
|
||||
isinstance(e, RayGetArgumentError)):
|
||||
# In this case, getting the task arguments failed.
|
||||
traceback_str = None
|
||||
else:
|
||||
traceback_str = traceback.format_exc()
|
||||
failure_object = RayTaskError(function_name, e, traceback_str)
|
||||
failure_objects = [failure_object for _
|
||||
in range(len(return_object_ids))]
|
||||
self._store_outputs_in_objstore(return_object_ids, failure_objects)
|
||||
# Log the error message.
|
||||
self.push_error_to_driver(self.task_driver_id.id(), "task",
|
||||
str(failure_object),
|
||||
data={"function_id": function_id.id(),
|
||||
"function_name": function_name})
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
"""Wait for a task to be ready and process the task.
|
||||
|
||||
Args:
|
||||
task: The task to execute.
|
||||
"""
|
||||
function_id = task.function_id()
|
||||
# Wait until the function to be executed has actually been registered
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with log_span("ray:wait_for_function", worker=self):
|
||||
self._wait_for_function(function_id, task.driver_id().id())
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
# warning to the user if we are waiting too long to acquire the lock
|
||||
# because that may indicate that the system is hanging, and it'd be
|
||||
# good to know where the system is hanging.
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self)
|
||||
with self.lock:
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END,
|
||||
worker=self)
|
||||
|
||||
function_name, _ = (self.functions[task.driver_id().id()]
|
||||
[function_id.id()])
|
||||
contents = {"function_name": function_name,
|
||||
"task_id": task.task_id().hex(),
|
||||
"worker_id": binary_to_hex(self.worker_id)}
|
||||
with log_span("ray:task", contents=contents, worker=self):
|
||||
self._process_task(task)
|
||||
|
||||
# Push all of the log events to the global state store.
|
||||
flush_log()
|
||||
|
||||
# Increase the task execution counter.
|
||||
(self.num_task_executions[task.driver_id().id()]
|
||||
[function_id.id()]) += 1
|
||||
|
||||
reached_max_executions = (
|
||||
self.num_task_executions[task.driver_id().id()]
|
||||
[function_id.id()] ==
|
||||
self.function_properties[task.driver_id().id()]
|
||||
[function_id.id()].max_calls)
|
||||
if reached_max_executions:
|
||||
ray.worker.global_worker.local_scheduler_client.disconnect()
|
||||
os._exit(0)
|
||||
|
||||
def _get_next_task_from_local_scheduler(self):
|
||||
"""Get the next task from the local scheduler.
|
||||
|
||||
Returns:
|
||||
A task from the local scheduler.
|
||||
"""
|
||||
with log_span("ray:get_task", worker=self):
|
||||
task = self.local_scheduler_client.get_task()
|
||||
return task
|
||||
|
||||
def main_loop(self):
|
||||
"""The main loop a worker runs to receive and execute tasks."""
|
||||
|
||||
def exit(signum, frame):
|
||||
cleanup(worker=self)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, exit)
|
||||
|
||||
check_main_thread()
|
||||
while True:
|
||||
task = self._get_next_task_from_local_scheduler()
|
||||
self._wait_for_and_process_task(task)
|
||||
|
||||
|
||||
def get_gpu_ids():
|
||||
"""Get the IDs of the GPU that are available to the worker.
|
||||
@@ -1731,45 +1985,6 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
return ready_ids, remaining_ids
|
||||
|
||||
|
||||
def wait_for_function(function_id, driver_id, timeout=10,
|
||||
worker=global_worker):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
relevant function. If we spend too long in this loop, that may indicate a
|
||||
problem somewhere and we will push an error message to the user.
|
||||
|
||||
If this worker is an actor, then this will wait until the actor has been
|
||||
defined.
|
||||
|
||||
Args:
|
||||
is_actor (bool): True if this worker is an actor, and false otherwise.
|
||||
function_id (str): The ID of the function that we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to if
|
||||
this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
while True:
|
||||
with worker.lock:
|
||||
if (worker.actor_id == NIL_ACTOR_ID and
|
||||
(function_id.id() in worker.functions[driver_id])):
|
||||
break
|
||||
elif worker.actor_id != NIL_ACTOR_ID and (worker.actor_id in
|
||||
worker.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
warning_message = ("This worker was asked to execute a "
|
||||
"function that it does not have "
|
||||
"registered. You may have to restart Ray.")
|
||||
if not warning_sent:
|
||||
worker.push_error_to_driver(driver_id, "wait_for_function",
|
||||
warning_message)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
|
||||
def format_error_message(exception_message, task_exception=False):
|
||||
"""Improve the formatting of an exception thrown by a remote function.
|
||||
|
||||
@@ -1792,145 +2007,7 @@ def format_error_message(exception_message, task_exception=False):
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def main_loop(worker=global_worker):
|
||||
"""The main loop a worker runs to receive and execute tasks."""
|
||||
|
||||
def exit(signum, frame):
|
||||
cleanup(worker=worker)
|
||||
sys.exit(0)
|
||||
|
||||
signal.signal(signal.SIGTERM, exit)
|
||||
|
||||
def process_task(task):
|
||||
"""Execute a task assigned to this worker.
|
||||
|
||||
This method deserializes a task from the scheduler, and attempts to
|
||||
execute the task. If the task succeeds, the outputs are stored in the
|
||||
local object store. If the task throws an exception, RayTaskError
|
||||
objects are stored in the object store to represent the failed task
|
||||
(these will be retrieved by calls to get or by subsequent tasks that
|
||||
use the outputs of this task).
|
||||
"""
|
||||
try:
|
||||
# The ID of the driver that this task belongs to. This is needed so
|
||||
# that if the task throws an exception, we propagate the error
|
||||
# message to the correct driver.
|
||||
worker.task_driver_id = task.driver_id()
|
||||
worker.current_task_id = task.task_id()
|
||||
worker.current_function_id = task.function_id().id()
|
||||
worker.task_index = 0
|
||||
worker.put_index = 0
|
||||
function_id = task.function_id()
|
||||
args = task.arguments()
|
||||
return_object_ids = task.returns()
|
||||
function_name, function_executor = (worker.functions
|
||||
[worker.task_driver_id.id()]
|
||||
[function_id.id()])
|
||||
|
||||
# Get task arguments from the object store.
|
||||
with log_span("ray:task:get_arguments", worker=worker):
|
||||
arguments = get_arguments_for_execution(function_name, args,
|
||||
worker)
|
||||
|
||||
# Execute the task.
|
||||
with log_span("ray:task:execute", worker=worker):
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
outputs = function_executor.executor(arguments)
|
||||
else:
|
||||
outputs = function_executor(
|
||||
worker.actors[task.actor_id().id()], *arguments)
|
||||
|
||||
# Store the outputs in the local object store.
|
||||
with log_span("ray:task:store_outputs", worker=worker):
|
||||
if len(return_object_ids) == 1:
|
||||
outputs = (outputs,)
|
||||
store_outputs_in_objstore(return_object_ids, outputs, worker)
|
||||
except Exception as e:
|
||||
# We determine whether the exception was caused by the call to
|
||||
# get_arguments_for_execution or by the execution of the remote
|
||||
# function or by the call to store_outputs_in_objstore. Depending
|
||||
# on which case occurred, we format the error message differently.
|
||||
# whether the variables "arguments" and "outputs" are defined.
|
||||
if "arguments" in locals() and "outputs" not in locals():
|
||||
if task.actor_id().id() == NIL_ACTOR_ID:
|
||||
# The error occurred during the task execution.
|
||||
traceback_str = format_error_message(
|
||||
traceback.format_exc(), task_exception=True)
|
||||
else:
|
||||
# The error occurred during the execution of an actor task.
|
||||
traceback_str = format_error_message(
|
||||
traceback.format_exc())
|
||||
elif "arguments" in locals() and "outputs" in locals():
|
||||
# The error occurred after the task executed.
|
||||
traceback_str = format_error_message(traceback.format_exc())
|
||||
else:
|
||||
# The error occurred before the task execution.
|
||||
if (isinstance(e, RayGetError) or
|
||||
isinstance(e, RayGetArgumentError)):
|
||||
# In this case, getting the task arguments failed.
|
||||
traceback_str = None
|
||||
else:
|
||||
traceback_str = traceback.format_exc()
|
||||
failure_object = RayTaskError(function_name, e, traceback_str)
|
||||
failure_objects = [failure_object for _
|
||||
in range(len(return_object_ids))]
|
||||
store_outputs_in_objstore(return_object_ids, failure_objects,
|
||||
worker)
|
||||
# Log the error message.
|
||||
worker.push_error_to_driver(worker.task_driver_id.id(), "task",
|
||||
str(failure_object),
|
||||
data={"function_id": function_id.id(),
|
||||
"function_name": function_name})
|
||||
|
||||
check_main_thread()
|
||||
while True:
|
||||
with log_span("ray:get_task", worker=worker):
|
||||
task = worker.local_scheduler_client.get_task()
|
||||
|
||||
function_id = task.function_id()
|
||||
# Wait until the function to be executed has actually been registered
|
||||
# on this worker. We will push warnings to the user if we spend too
|
||||
# long in this loop.
|
||||
with log_span("ray:wait_for_function", worker=worker):
|
||||
wait_for_function(function_id, task.driver_id().id(),
|
||||
worker=worker)
|
||||
|
||||
# Execute the task.
|
||||
# TODO(rkn): Consider acquiring this lock with a timeout and pushing a
|
||||
# warning to the user if we are waiting too long to acquire the lock
|
||||
# because that may indicate that the system is hanging, and it'd be
|
||||
# good to know where the system is hanging.
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker)
|
||||
with worker.lock:
|
||||
log(event_type="ray:acquire_lock", kind=LOG_SPAN_END,
|
||||
worker=worker)
|
||||
|
||||
function_name, _ = (worker.functions[task.driver_id().id()]
|
||||
[function_id.id()])
|
||||
contents = {"function_name": function_name,
|
||||
"task_id": task.task_id().hex(),
|
||||
"worker_id": binary_to_hex(worker.worker_id)}
|
||||
with log_span("ray:task", contents=contents, worker=worker):
|
||||
process_task(task)
|
||||
|
||||
# Push all of the log events to the global state store.
|
||||
flush_log()
|
||||
|
||||
# Increase the task execution counter.
|
||||
(worker.num_task_executions[task.driver_id().id()]
|
||||
[function_id.id()]) += 1
|
||||
|
||||
reached_max_executions = (
|
||||
worker.num_task_executions[task.driver_id().id()]
|
||||
[function_id.id()] ==
|
||||
worker.function_properties[task.driver_id().id()]
|
||||
[function_id.id()].max_calls)
|
||||
if reached_max_executions:
|
||||
ray.worker.global_worker.local_scheduler_client.disconnect()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
def _submit_task(function_id, args, worker=global_worker):
|
||||
"""This is a wrapper around worker.submit_task.
|
||||
|
||||
We use this wrapper so that in the remote decorator, we can call
|
||||
@@ -1938,7 +2015,7 @@ def _submit_task(function_id, func_name, args, worker=global_worker):
|
||||
attempt to serialize remote functions, we don't attempt to serialize the
|
||||
worker object, which cannot be serialized.
|
||||
"""
|
||||
return worker.submit_task(function_id, func_name, args)
|
||||
return worker.submit_task(function_id, args)
|
||||
|
||||
|
||||
def _mode(worker=global_worker):
|
||||
@@ -2081,7 +2158,7 @@ def remote(*args, **kwargs):
|
||||
# immutable remote objects.
|
||||
result = func(*copy.deepcopy(args))
|
||||
return result
|
||||
objectids = _submit_task(function_id, func_name, args)
|
||||
objectids = _submit_task(function_id, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
@@ -2157,69 +2234,3 @@ def remote(*args, **kwargs):
|
||||
assert "function_id" not in kwargs
|
||||
return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
|
||||
max_calls)
|
||||
|
||||
|
||||
def get_arguments_for_execution(function_name, serialized_args,
|
||||
worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that
|
||||
were passed in as object IDs. Argumens that were passed by value are not
|
||||
changed. This is called by the worker that is executing the remote
|
||||
function.
|
||||
|
||||
Args:
|
||||
function_name (str): The name of the remote function whose arguments
|
||||
are being retrieved.
|
||||
serialized_args (List): The arguments to the function. These are either
|
||||
strings representing serialized objects passed by value or they are
|
||||
ObjectIDs.
|
||||
|
||||
Returns:
|
||||
The retrieved arguments in addition to the arguments that were passed
|
||||
by value.
|
||||
|
||||
Raises:
|
||||
RayGetArgumentError: This exception is raised if a task that created
|
||||
one of the arguments failed.
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, ray.local_scheduler.ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = worker.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
# If the result is a RayTaskError, then the task that created
|
||||
# this object failed, and we should propagate the error message
|
||||
# here.
|
||||
raise RayGetArgumentError(function_name, i, arg, argument)
|
||||
else:
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
|
||||
"""Store the outputs of a remote function in the local object store.
|
||||
|
||||
This stores the values that were returned by a remote function in the local
|
||||
object store. If any of the return values are object IDs, then these object
|
||||
IDs are aliased with the object IDs that the scheduler assigned for the
|
||||
return values. This is called by the worker that executes the remote
|
||||
function.
|
||||
|
||||
Note:
|
||||
The arguments objectids and outputs should have the same length.
|
||||
|
||||
Args:
|
||||
objectids (List[ObjectID]): The object IDs that were assigned to the
|
||||
outputs of the remote function call.
|
||||
outputs (Tuple): The value returned by the remote function. If the
|
||||
remote function was supposed to only return one value, then its
|
||||
output was wrapped in a tuple with one element prior to being
|
||||
passed into this function.
|
||||
"""
|
||||
for i in range(len(objectids)):
|
||||
worker.put_object(objectids[i], outputs[i])
|
||||
|
||||
@@ -30,6 +30,31 @@ def random_string():
|
||||
return np.random.bytes(20)
|
||||
|
||||
|
||||
def create_redis_client(redis_address):
|
||||
redis_ip_address, redis_port = redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine
|
||||
# as Redis) must have run "CONFIG SET protected-mode no".
|
||||
return redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
|
||||
|
||||
|
||||
def push_error_to_all_drivers(redis_client, message):
|
||||
"""Push an error message to all drivers.
|
||||
|
||||
Args:
|
||||
redis_client: The redis client to use.
|
||||
message: The error message to push.
|
||||
"""
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
# Create a Redis client.
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
"message": message})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
info = {"node_ip_address": args.node_ip_address,
|
||||
@@ -57,25 +82,12 @@ if __name__ == "__main__":
|
||||
# task) should be caught and handled inside of the call to
|
||||
# main_loop. If an exception is thrown here, then that means that
|
||||
# there is some error that we didn't anticipate.
|
||||
ray.worker.main_loop()
|
||||
ray.worker.global_worker.main_loop()
|
||||
except Exception as e:
|
||||
traceback_str = traceback.format_exc() + error_explanation
|
||||
DRIVER_ID_LENGTH = 20
|
||||
# We use a driver ID of all zeros to push an error message to all
|
||||
# drivers.
|
||||
driver_id = DRIVER_ID_LENGTH * b"\x00"
|
||||
error_key = b"Error:" + driver_id + b":" + random_string()
|
||||
redis_ip_address, redis_port = args.redis_address.split(":")
|
||||
# For this command to work, some other client (on the same machine
|
||||
# as Redis) must have run "CONFIG SET protected-mode no".
|
||||
redis_client = redis.StrictRedis(host=redis_ip_address,
|
||||
port=int(redis_port))
|
||||
redis_client.hmset(error_key, {"type": "worker_crash",
|
||||
"message": traceback_str,
|
||||
"note": ("This error is unexpected "
|
||||
"and should not have "
|
||||
"happened.")})
|
||||
redis_client.rpush("ErrorKeys", error_key)
|
||||
# Create a Redis client.
|
||||
redis_client = create_redis_client(args.redis_address)
|
||||
push_error_to_all_drivers(redis_client, traceback_str)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
||||
Reference in New Issue
Block a user