mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
b98a63fd3a
* Change plasma_get to take a timeout and an array of object IDs. * Address comments. * Bug fix related to computing object hashes. * Add test. * Fix file descriptor leak. * Fix valgrind. * Formatting. * Remove call to plasma_contains from the plasma client. Use timeout internally in ray.get. * small fixes
1753 lines
78 KiB
Python
1753 lines
78 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import json
|
|
import hashlib
|
|
import os
|
|
import sys
|
|
import time
|
|
import traceback
|
|
import copy
|
|
import funcsigs
|
|
import numpy as np
|
|
import colorama
|
|
import atexit
|
|
import random
|
|
import redis
|
|
import threading
|
|
import string
|
|
|
|
# Ray modules
|
|
import ray.pickling as pickling
|
|
import ray.serialization as serialization
|
|
import ray.services as services
|
|
import numbuf
|
|
import photon
|
|
import plasma
|
|
|
|
# Worker modes accepted by Worker.set_mode. SCRIPT_MODE, PYTHON_MODE and
# SILENT_MODE are driver modes; WORKER_MODE is for non-driver processes.
SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, SILENT_MODE = range(4)

# Log event kinds. NOTE(review): presumably consumed by the logging helpers
# (e.g. log_span) defined elsewhere in this file -- confirm.
LOG_POINT, LOG_SPAN_START, LOG_SPAN_END = range(3)
|
|
|
|
def random_string():
    """Return 20 random bytes drawn from NumPy's global random state."""
    num_bytes = 20
    return np.random.bytes(num_bytes)
|
|
|
|
def random_object_id():
    """Return a photon.ObjectID built from a fresh random_string() value."""
    raw_id = random_string()
    return photon.ObjectID(raw_id)
|
|
|
|
class FunctionID(object):
    """Thin wrapper around the identifier of a remote function.

    Attributes:
        function_id: The identifier exactly as passed to the constructor.
    """

    def __init__(self, function_id):
        """Store function_id without any validation or conversion."""
        self.function_id = function_id

    def id(self):
        """Return the wrapped identifier."""
        return self.function_id
|
|
|
|
# Object IDs seen by the custom ObjectID serializer (registered in
# initialize_numbuf) while a value is being serialized.
contained_objectids = []


def numbuf_serialize(value):
    """Serialize a value while tracking the object IDs contained in it.

    The custom ObjectID serializer closes over the global list
    contained_objectids and appends every ObjectID it serializes to that
    list. The list must therefore be reset between calls to this function;
    that precondition is asserted here.

    Args:
        value: A Python object that will be serialized.

    Returns:
        The serialized object.
    """
    assert not len(contained_objectids), "This should be unreachable."
    # numbuf serializes lists, so the value is wrapped in a one-element list.
    wrapped_value = [value]
    return numbuf.serialize_list(wrapped_value)
|
|
|
|
class RayTaskError(Exception):
    """An object used internally to represent a task that threw an exception.

    If a task throws an exception during execution, a RayTaskError is stored
    in the object store for each of the task's outputs. When an object is
    retrieved from the object store, the Python method that retrieved it
    checks whether the object is a RayTaskError and, if so, throws an
    exception propagating the error message.

    Currently, either the exception attribute or the traceback attribute is
    used, but not both.

    Attributes:
        function_name (str): The name of the function that failed and
            produced the RayTaskError.
        exception (Exception): The exception object thrown by the failed
            task. Only RayGetError and RayGetArgumentError instances are
            kept; any other exception is replaced by None.
        traceback_str (str): The traceback from the exception.
    """

    def __init__(self, function_name, exception, traceback_str):
        """Initialize a RayTaskError."""
        self.function_name = function_name
        # Keep the exception object only for the two "get" failure types.
        keep_exception = isinstance(exception,
                                    (RayGetError, RayGetArgumentError))
        self.exception = exception if keep_exception else None
        self.traceback_str = traceback_str

    def __str__(self):
        """Format a RayTaskError as a string."""
        if self.traceback_str is None:
            # This path is taken if getting the task arguments failed.
            details = self.exception
        else:
            # This path is taken if the task execution failed.
            details = self.traceback_str
        return "Remote function {}{}{} failed with:\n\n{}".format(
            colorama.Fore.RED, self.function_name, colorama.Fore.RESET,
            details)
|
|
|
|
class RayGetError(Exception):
    """An exception used when get is called on an output of a failed task.

    Attributes:
        objectid (lib.ObjectID): The ObjectID that get was called on.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, objectid, task_error):
        """Initialize a RayGetError object."""
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetError as a string."""
        template = "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}"
        failed_function = self.task_error.function_name
        return template.format(self.objectid,
                               colorama.Fore.RED,
                               failed_function,
                               colorama.Fore.RESET,
                               self.task_error)
|
|
|
|
class RayGetArgumentError(Exception):
    """An exception used when a task's argument was produced by a failed task.

    Attributes:
        argument_index (int): The index (zero indexed) of the failed argument
            in the present task's remote function call.
        function_name (str): The name of the function for the current task.
        objectid (lib.ObjectID): The ObjectID that was passed in as the
            argument.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, function_name, argument_index, objectid, task_error):
        """Initialize a RayGetArgumentError object."""
        self.function_name = function_name
        self.argument_index = argument_index
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetArgumentError as a string."""
        template = "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}"
        red, reset = colorama.Fore.RED, colorama.Fore.RESET
        return template.format(self.objectid, self.argument_index,
                               red, self.function_name, reset,
                               red, self.task_error.function_name, reset,
                               self.task_error)
|
|
|
|
|
|
class EnvironmentVariable(object):
    """A Python object that can be shared between tasks.

    Attributes:
        initializer (Callable[[], object]): A function used to create and
            initialize the environment variable.
        reinitializer (Optional[Callable[[object], object]]): An optional
            function used to reinitialize the environment variable after it
            has been used. It can serve as an optimization when there is a
            fast way to reset the state of the variable other than rerunning
            the initializer.
    """

    def __init__(self, initializer, reinitializer=None):
        """Initialize an EnvironmentVariable object.

        Raises:
            Exception: If initializer, or a supplied reinitializer, is not
                callable.
        """
        if not callable(initializer):
            raise Exception("When creating an EnvironmentVariable, initializer must be a function.")
        self.initializer = initializer

        if reinitializer is None:
            # No reinitializer given: fall back to rerunning the initializer
            # and ignoring the current value.
            def rerun_initializer(value):
                return initializer()
            reinitializer = rerun_initializer
        elif not callable(reinitializer):
            raise Exception("When creating an EnvironmentVariable, reinitializer must be a function.")
        self.reinitializer = reinitializer
|
|
|
|
class RayEnvironmentVariables(object):
    """An object used to store Python variables that are shared between tasks.

    Each worker process will have a single RayEnvironmentVariables object.
    This class serves two purposes. First, some objects are not serializable,
    and so the code that creates those objects must be run on the worker that
    uses them. This class is responsible for running the code that creates
    those objects. Second, some of these objects are expensive to create, and
    so they should be shared between tasks. However, if a task mutates a
    variable that is shared between tasks, then the behavior of the overall
    program may be nondeterministic (it could depend on scheduling
    decisions). To fix this, if a task uses one of these shared objects, then
    that shared object will be reinitialized after the task finishes. Since
    the initialization may be expensive, the user can pass in custom
    reinitialization code that resets the state of the shared variable to the
    way it was after initialization. If the reinitialization code does not do
    this, then the behavior of the overall program is undefined.

    Attributes:
        _names (Set[str]): The names of all the environment variables.
        _reinitializers (Dict[str, Callable]): A dictionary mapping the name
            of the environment variables to the corresponding reinitializer.
        _running_remote_function_locally (bool): A flag used to indicate if a
            remote function is running locally on the driver so that we can
            simulate the same behavior as running a remote function remotely.
        _environment_variables: A dictionary mapping the name of an
            environment variable to the value of the environment variable.
        _local_mode_environment_variables: A copy of _environment_variables
            used on the driver when running remote functions locally on the
            driver. This is needed because there are two ways in which
            environment variables can be used on the driver. The first is
            that the driver's copy can be manipulated. This copy is never
            reset (think of the driver as a single long-running task). The
            second way is that a remote function can be run locally on the
            driver, and this remote function needs access to a copy of the
            environment variable, and that copy must be reinitialized after
            use.
        _cached_environment_variables (List[Tuple[str, EnvironmentVariable]]):
            A list of pairs. The first element of each pair is the name of an
            environment variable, and the second element is the
            EnvironmentVariable object. This list is used to store
            environment variables that are defined before the driver is
            connected. Once the driver is connected, these variables will be
            exported.
        _used (Set[str]): The names of all the environment variables that
            have been accessed within the scope of the current task. This is
            emptied after each task.
        _slots (Tuple[str]): The whitelisted attribute names that bypass the
            environment-variable handling in __getattribute__/__setattr__.
    """

    def __init__(self):
        """Initialize a RayEnvironmentVariables object."""
        self._names = set()
        self._reinitializers = {}
        self._running_remote_function_locally = False
        self._environment_variables = {}
        self._local_mode_environment_variables = {}
        self._cached_environment_variables = []
        self._used = set()
        # CHECKPOINT: _slots must be assigned last. Until it exists,
        # __setattr__ stores every attribute directly; once it exists, any
        # name not listed here is treated as an environment variable.
        self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_environment_variables", "_local_mode_environment_variables", "_cached_environment_variables", "_used", "_slots", "_create_environment_variable", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")

    def _create_environment_variable(self, name, environment_variable):
        """Create an environment variable locally.

        Args:
            name (str): The name of the environment variable.
            environment_variable (EnvironmentVariable): The environment
                variable object to use to create the environment variable.
        """
        self._names.add(name)
        self._reinitializers[name] = environment_variable.reinitializer
        self._environment_variables[name] = environment_variable.initializer()
        # On a driver, a second, independent copy is created for use inside
        # remote functions that run locally. This occurs when Ray is started
        # in PYTHON_MODE and when a remote function is called locally.
        driver_modes = (SCRIPT_MODE, SILENT_MODE, PYTHON_MODE)
        if _mode() in driver_modes:
            local_copy = environment_variable.initializer()
            self._local_mode_environment_variables[name] = local_copy

    def _reinitialize(self):
        """Reinitialize the environment variables that the current task used."""
        driver_modes = (SCRIPT_MODE, SILENT_MODE, PYTHON_MODE)
        for name in self._used:
            # NOTE(review): the value handed to the reinitializer always
            # comes from _environment_variables, even on the driver where the
            # result is stored in _local_mode_environment_variables --
            # confirm this is intended.
            reinitialize = self._reinitializers[name]
            new_value = reinitialize(self._environment_variables[name])
            if _mode() in driver_modes:
                # On the driver only the local-mode copy is reset.
                assert self._running_remote_function_locally
                self._local_mode_environment_variables[name] = new_value
            else:
                self._environment_variables[name] = new_value
        # Forget which variables were used by the task that just finished.
        self._used.clear()

    def __getattribute__(self, name):
        """Get an attribute. This handles environment variables as a special case.

        When __getattribute__ is called with the name of an environment
        variable, that name is recorded in the set of variables that were
        used in the current task.

        Args:
            name (str): The name of the attribute to get.
        """
        # "_slots" itself is checked first so that the self._slots lookup
        # below does not recurse indefinitely. Whitelisted names and names
        # that are not environment variables use the default lookup.
        if (name == "_slots" or name in self._slots
                or name not in self._names):
            return object.__getattribute__(self, name)
        # Make a note of the fact that the environment variable was used.
        self._used.add(name)
        if self._running_remote_function_locally:
            source = self._local_mode_environment_variables
        else:
            source = self._environment_variables
        return source[name]

    def __setattr__(self, name, value):
        """Set an attribute. This handles environment variables as a special case.

        This is used to create environment variables. When it is called, it
        runs the function for initializing the variable to create the
        variable. If this is called on the driver, then the functions for
        initializing and reinitializing the variable are shipped to the
        workers.

        If this is called before ray.init has been run, then the environment
        variable will be cached and it will be created and exported when
        connect is called.

        Args:
            name (str): The name of the attribute to set. This is either a
                whitelisted name or it is treated as the name of an
                environment variable.
            value: If name is a whitelisted name, then value can be any
                value. If name is the name of an environment variable, then
                this is an EnvironmentVariable object.

        Raises:
            Exception: If name is not whitelisted and value is not an
                EnvironmentVariable.
        """
        # While __init__ is still running _slots does not exist yet; in that
        # phase every attribute is stored directly.
        slots = getattr(self, "_slots", ())
        if slots == () or name in slots:
            return object.__setattr__(self, name, value)

        if not issubclass(type(value), EnvironmentVariable):
            raise Exception("To set an environment variable, you must pass in an EnvironmentVariable object")

        mode = _mode()
        if mode is None:
            # ray.init has not been called: cache the variable so it can be
            # exported later.
            self._cached_environment_variables.append((name, value))
        else:
            # If we are on the driver, export the environment variable to
            # all the workers.
            if mode in [SCRIPT_MODE, SILENT_MODE]:
                _export_environment_variable(name, value)
            # Define the environment variable locally.
            self._create_environment_variable(name, value)
            # Create an empty attribute with the name of the environment
            # variable so that tab completion in the interpreter works.
            object.__setattr__(self, name, None)

    def __delattr__(self, name):
        """We do not allow attributes of RayEnvironmentVariables to be deleted.

        Args:
            name (str): The name of the attribute to delete.

        Raises:
            Exception: Always.
        """
        raise Exception("Attempted deletion of attribute {}. Attributes of a RayEnvironmentVariables object may not be deleted.".format(name))
|
|
|
|
class Worker(object):
    """A class used to define the control flow of a worker process.

    Note:
        The methods in this class are considered unexposed to the user. The
        functions outside of this class are considered exposed.

    Attributes:
        functions (Dict[str, Callable]): A dictionary mapping the name of a
            remote function to the remote function itself. This is the set of
            remote functions that can be executed by this worker.
        num_return_vals (dict): Looked up by function_id.id() in submit_task
            to get the number of return values passed to photon.Task.
        connected (bool): True if Ray has been started and False otherwise.
        mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE,
            SILENT_MODE, and WORKER_MODE.
        cached_remote_functions (List[Tuple[str, str]]): A list of pairs
            representing the remote functions that were defined before the
            worker called connect. The first element is the name of the
            remote function, and the second element is the serialized remote
            function. When the worker eventually does call connect, if it is
            a driver, it will export these functions to the scheduler. If
            cached_remote_functions is None, that means that connect has been
            called already.
        cached_functions_to_run (List): A list of functions to run on all of
            the workers that should be exported as soon as connect is called.
        driver_export_counter (int): The number of exports that the driver
            has exported. This is only used on the driver.
        worker_import_counter (int): The number of exports that the worker
            has imported so far. This is only used on the workers.

    The methods also read plasma_client, photon_client, redis_client,
    node_ip_address, current_task_id and task_index. None of these is set in
    __init__; presumably they are assigned when the worker connects.
    """

    def __init__(self):
        """Initialize a Worker object."""
        self.functions = {}
        self.num_return_vals = {}
        self.function_names = {}
        self.function_export_counters = {}
        self.connected = False
        self.mode = None
        self.cached_remote_functions = []
        self.cached_functions_to_run = []
        self.driver_export_counter = 0
        self.worker_import_counter = 0

    def set_mode(self, mode):
        """Set the mode of the worker.

        The mode SCRIPT_MODE should be used if this Worker is a driver that
        is being run as a Python script or interactively in a shell. It will
        print information about task failures.

        The mode WORKER_MODE should be used if this Worker is not a driver.
        It will not print information about tasks.

        The mode PYTHON_MODE should be used if this Worker is a driver and if
        you want to run the driver in a manner equivalent to serial Python
        for debugging purposes. It will not send remote function calls to the
        scheduler and will instead execute them in a blocking fashion.

        The mode SILENT_MODE should be used only during testing. It does not
        print any information about errors because some of the tests
        intentionally fail.

        Args:
            mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and
                SILENT_MODE.
        """
        self.mode = mode
        colorama.init()

    def put_object(self, objectid, value):
        """Put value in the local object store with object id objectid.

        This assumes that the value for objectid has not yet been placed in
        the local object store.

        Args:
            objectid (object_id.ObjectID): The object ID of the value to be
                put.
            value (serializable object): The value to put in the object
                store.
        """
        # Serialize and put the object in the object store. The value is
        # wrapped in a list; get_object unwraps it again.
        try:
            numbuf.store_list(objectid.id(), self.plasma_client.conn, [value])
        except plasma.plasma_object_exists_error:
            # The object already exists in the object store, so there is no
            # need to add it again. TODO(rkn): We need to compare the hashes
            # and make sure that the objects are in fact the same. We also
            # should return an error code to the caller instead of printing
            # a message.
            print("This object already exists in the object store.")
            return
        global contained_objectids
        # Optionally do something with the contained_objectids here.
        contained_objectids = []

    def get_object(self, object_ids):
        """Get the values in the local object store associated with object_ids.

        This blocks until all the values for object_ids have been written to
        the local object store.

        Args:
            object_ids (List[object_id.ObjectID]): A list of the object IDs
                whose values should be retrieved.

        Returns:
            A list with one value per entry of object_ids, in the same order.
            An empty list is returned for an empty request.
        """
        # An empty request has nothing to wait for. Returning here also
        # guarantees that the retrieval loop below runs at least once, so
        # `results` is always bound when it is inspected.
        if len(object_ids) == 0:
            return []
        plasma_ids = [object_id.id() for object_id in object_ids]
        self.plasma_client.fetch(plasma_ids)
        while True:
            # We currently pass in a timeout of one second (the 1000).
            results = numbuf.retrieve_list(plasma_ids,
                                           self.plasma_client.conn, 1000)
            unready_ids = [result_id for (result_id, val) in results
                           if val is None]
            if len(unready_ids) == 0:
                break
            # This would be a natural place to issue a command to
            # reconstruct some of the objects.
        assert len(results) == len(object_ids)
        for plasma_id, (result_id, _) in zip(plasma_ids, results):
            assert result_id == plasma_id
        # Unwrap each object from the list it was wrapped in by put_object.
        return [val[0] for (_, val) in results]

    def submit_task(self, function_id, func_name, args):
        """Submit a remote task to the scheduler.

        Tell the scheduler to schedule the execution of the function with
        name func_name with arguments args. Retrieve object IDs for the
        outputs of the function from the scheduler and immediately return
        them.

        Args:
            function_id: An object whose id() method gives the identifier of
                the function to be executed.
            func_name (str): The name of the function to be executed.
            args (List[Any]): The arguments to pass into the function.
                Arguments can be object IDs or they can be values. If they
                are values, they must be serializable objects.

        Returns:
            The result of task.returns() for the submitted task.
        """
        with log_span("ray:submit_task", worker=self):
            check_main_thread()
            # Put large or complex arguments that are passed by value in the
            # object store first; object IDs and simple values pass through.
            args_for_photon = []
            for arg in args:
                if isinstance(arg, photon.ObjectID):
                    args_for_photon.append(arg)
                elif photon.check_simple_value(arg):
                    args_for_photon.append(arg)
                else:
                    args_for_photon.append(put(arg))

            # Submit the task to Photon.
            task = photon.Task(photon.ObjectID(function_id.id()),
                               args_for_photon,
                               self.num_return_vals[function_id.id()],
                               self.current_task_id,
                               self.task_index)
            # Increment the worker's task index to track how many tasks have
            # been submitted by the current task so far.
            self.task_index += 1
            self.photon_client.submit(task)

            return task.returns()

    def run_function_on_all_workers(self, function):
        """Run arbitrary code on all of the workers.

        This function will first be run on the driver, and then it will be
        exported to all of the workers to be run. It will also be run on any
        new workers that register later. If ray.init has not been called yet,
        then cache the function and export it later.

        Args:
            function (Callable): The function to run on all of the workers.
                It is called with a single dictionary argument containing
                the key "counter". If it returns anything, its return values
                will not be used.

        Raises:
            Exception: If this is called on a worker that is not a driver.
        """
        check_main_thread()
        if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
            raise Exception("run_function_on_all_workers can only be called on a driver.")
        # If ray.init has not been called yet, then cache the function and
        # export it when connect is called. Otherwise, run the function on
        # all workers.
        if self.mode is None:
            self.cached_functions_to_run.append(function)
        else:
            function_to_run_id = random_string()
            key = "FunctionsToRun:{}".format(function_to_run_id)
            # First run the function on the driver. Pass in the number of
            # workers on this node that have already started executing this
            # remote function, and increment that value. Subtract 1 so that
            # the counter starts at 0.
            counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1
            function({"counter": counter})
            # Run the function on all workers.
            self.redis_client.hmset(key, {"function_id": function_to_run_id,
                                          "function": pickling.dumps(function)})
            self.redis_client.rpush("Exports", key)
            self.driver_export_counter += 1
|
|
|
|
global_worker = Worker()
"""Worker: The global Worker object for this worker process.

A single module-level Worker is used so that each worker process has exactly
one worker object.
"""

env = RayEnvironmentVariables()
"""RayEnvironmentVariables: The environment variables that are shared by tasks.

Each worker process has its own RayEnvironmentVariables object, and these
objects should be the same in all workers. It stores variables that are not
serializable but must be used by remote tasks. It is also used to reinitialize
these variables after they are used, so that changes made to their state by
one task do not affect other tasks.
"""
|
|
|
|
class RayConnectionError(Exception):
    """Raised by check_connected when the worker is not connected to Ray."""
|
|
|
|
def check_main_thread():
    """Check that we are currently on the main thread.

    Raises:
        Exception: An exception is raised if this is called on a thread other
            than the main thread.
    """
    # Thread.getName() is deprecated since Python 3.10; the name attribute
    # gives the same value. The main thread is recognised by its name.
    thread_name = threading.current_thread().name
    if thread_name != "MainThread":
        raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(thread_name))
|
|
|
|
def check_connected(worker=global_worker):
    """Check if the worker is connected.

    Args:
        worker: The worker whose connected flag is inspected. Defaults to
            the global worker.

    Raises:
        RayConnectionError: Raised if the worker is not connected.
    """
    if worker.connected:
        return
    raise RayConnectionError("This command cannot be called before Ray has been started. You can start Ray with 'ray.init(num_workers=10)'.")
|
|
|
|
def print_failed_task(task_status):
    """Print information about failed tasks.

    Args:
        task_status (Dict): A dictionary containing the name
            ("function_name"), operationid ("operationid"), and error message
            ("error_message") for a failed task.
    """
    function_name = task_status["function_name"]
    task_id = task_status["operationid"]
    error_message = task_status["error_message"]
    # The error message starts on its own line (note the embedded newline).
    report = """
    Error: Task failed
      Function Name: {}
      Task ID: {}
      Error Message: \n{}
  """.format(function_name, task_id, error_message)
    print(report)
|
|
|
|
def error_info(worker=global_worker):
    """Return information about failed tasks.

    Args:
        worker: The worker whose redis_client is queried. Defaults to the
            global worker.

    Returns:
        A dictionary mapping each error type (bytes) to a list of the hashes
        stored in Redis for errors of that type. The six known error types
        are always present, possibly with an empty list.

    Raises:
        RayConnectionError: Raised by check_connected if the worker is not
            connected.
        Exception: Raised by check_main_thread if called off the main thread.
    """
    check_connected(worker)
    check_main_thread()
    known_error_types = (b"TaskError",
                         b"RemoteFunctionImportError",
                         b"EnvironmentVariableImportError",
                         b"EnvironmentVariableReinitializeError",
                         b"FunctionToRunError",
                         b"GenericWarning")
    result = {error_type: [] for error_type in known_error_types}
    error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
    for error_key in error_keys:
        # The error type is the part of the key before the first colon.
        error_type = error_key.split(b":", 1)[0]
        error_contents = worker.redis_client.hgetall(error_key)
        # An error type that is not listed above is reported under its own
        # key instead of raising KeyError and hiding every other error.
        result.setdefault(error_type, []).append(error_contents)
    return result
|
|
|
|
def initialize_numbuf(worker=global_worker):
    """Initialize the serialization library.

    This defines a custom serializer for object IDs and also tells numbuf to
    serialize several exception classes that we define for error handling.

    Args:
        worker: The worker whose mode decides whether the exception classes
            are registered. Defaults to the global worker.
    """
    # Define a custom serializer and deserializer for handling Object IDs.
    def objectid_custom_serializer(obj):
        # NOTE(review): the class identifier is computed but never used; the
        # call is retained so that behavior is unchanged.
        serialization.class_identifier(type(obj))
        # Record every ObjectID that gets serialized.
        contained_objectids.append(obj)
        return obj.id()

    def objectid_custom_deserializer(serialized_obj):
        return photon.ObjectID(serialized_obj)

    serialization.add_class_to_whitelist(
        photon.ObjectID,
        pickle=False,
        custom_serializer=objectid_custom_serializer,
        custom_deserializer=objectid_custom_deserializer)

    if worker.mode not in [SCRIPT_MODE, SILENT_MODE]:
        return
    # These should only be called on the driver because register_class will
    # export the class to all of the workers.
    for error_class in (RayTaskError, RayGetError, RayGetArgumentError):
        register_class(error_class)
|
|
|
|
def get_address_info_from_redis_helper(redis_address, node_ip_address):
    """Read the addresses of the Ray processes on one node from Redis.

    Args:
        redis_address (str): The Redis address in the form "host:port".
        node_ip_address (str): The IP address of the node whose plasma
            managers and local schedulers should be looked up.

    Returns:
        A dictionary with the keys "node_ip_address", "redis_address",
        "object_store_addresses" and "local_scheduler_socket_names".

    Raises:
        ValueError: If redis_address does not split into exactly a host and
            a port.
        AssertionError: If a client entry lacks a required field, or if no
            plasma manager or no local scheduler is registered for the node.
    """
    redis_host, redis_port = redis_address.split(":")
    # For this command to work, some other client (on the same machine as
    # Redis) must have run "CONFIG SET protected-mode no".
    redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
    # The client table prefix must be kept in sync with the file
    # "src/common/redis_module/ray_redis_module.c" where it is defined.
    REDIS_CLIENT_TABLE_PREFIX = "CL:"
    client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))

    # Filter to clients on the same node and do some basic checking.
    plasma_managers = []
    local_schedulers = []
    required_fields = (b"ray_client_id", b"node_ip_address", b"client_type")
    for key in client_keys:
        info = redis_client.hgetall(key)
        for field in required_fields:
            assert field in info
        if info[b"node_ip_address"].decode("ascii") != node_ip_address:
            continue
        client_type = info[b"client_type"].decode("ascii")
        if client_type == "plasma_manager":
            plasma_managers.append(info)
        elif client_type == "photon":
            local_schedulers.append(info)

    # Make sure that we got at least one plasma manager and local scheduler.
    assert len(plasma_managers) >= 1
    assert len(local_schedulers) >= 1

    # Build the address information.
    object_store_addresses = []
    for manager_info in plasma_managers:
        manager_address = manager_info[b"address"].decode("ascii")
        manager_port = services.get_port(manager_address)
        store_address = services.ObjectStoreAddress(
            name=manager_info[b"store_socket_name"].decode("ascii"),
            manager_name=manager_info[b"manager_socket_name"].decode("ascii"),
            manager_port=manager_port)
        object_store_addresses.append(store_address)

    scheduler_names = []
    for scheduler_info in local_schedulers:
        socket_name = scheduler_info[b"local_scheduler_socket_name"]
        scheduler_names.append(socket_name.decode("ascii"))

    return {"node_ip_address": node_ip_address,
            "redis_address": redis_address,
            "object_store_addresses": object_store_addresses,
            "local_scheduler_socket_names": scheduler_names}
|
|
|
|
def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5):
    """Look up address information in Redis, retrying on failure.

    Args:
        redis_address (str): Passed through to
            get_address_info_from_redis_helper.
        node_ip_address (str): Passed through to
            get_address_info_from_redis_helper.
        num_retries (int): The number of failed attempts that are retried;
            the next failure after that is re-raised.

    Returns:
        The value returned by get_address_info_from_redis_helper.

    Raises:
        Exception: The helper's exception, once the retries are exhausted.
    """
    failed_attempts = 0
    while True:
        try:
            return get_address_info_from_redis_helper(redis_address,
                                                      node_ip_address)
        except Exception:
            if failed_attempts == num_retries:
                raise
        # Some of the information may not be in Redis yet, so wait a little
        # bit before trying again.
        print("Some processes that the driver needs to connect to have not registered with Redis, so retrying.")
        time.sleep(1)
        failed_attempts += 1
|
|
|
|
def _init(address_info=None, start_ray_local=False, object_id_seed=None,
          num_workers=None, num_local_schedulers=None,
          driver_mode=SCRIPT_MODE):
  """Helper method to connect to an existing Ray cluster or start a new one.

  Two cases are handled. Either a Ray cluster already exists and this driver
  is simply attached to it, or the processes that make up a Ray cluster are
  started first and the driver is attached to the newly started cluster.

  Args:
    address_info (dict): A dictionary with address information for processes in
      a partially-started Ray cluster. If start_ray_local=True, any processes
      not in this dictionary will be started. If provided, address_info will be
      modified to include processes that are newly started.
    start_ray_local (bool): If True then this will start any processes not
      already in address_info, including Redis, a global scheduler, local
      scheduler(s), object store(s), and worker(s). It will also kill these
      processes when Python exits. If False, this will attach to an existing
      Ray cluster.
    object_id_seed (int): Used to seed the deterministic generation of object
      IDs. The same value can be used across multiple runs of the same job in
      order to generate the object IDs in a consistent manner. However, the same
      ID should not be used for different jobs.
    num_workers (int): The number of workers to start. This is only provided if
      start_ray_local is True. Defaults to 10 in that case.
    num_local_schedulers (int): The number of local schedulers to start. This is
      only provided if start_ray_local is True.
    driver_mode: The mode in which to start the driver. This should be one of
      ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.

  Returns:
    Address information about the started processes.

  Raises:
    Exception: An exception is raised if an inappropriate combination of
      arguments is passed in.
  """
  check_main_thread()
  if driver_mode not in (SCRIPT_MODE, PYTHON_MODE, SILENT_MODE):
    raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")

  # Collect whatever is already known about existing services.
  if address_info is not None:
    assert isinstance(address_info, dict)
  else:
    address_info = {}
  node_ip_address = address_info.get("node_ip_address")
  redis_address = address_info.get("redis_address")

  # In PYTHON_MODE no other processes are started or looked up.
  if driver_mode != PYTHON_MODE:
    if start_ray_local:
      # Launch a scheduler, a new object store, and some workers, and connect
      # to them. Processes already registered in address_info are not
      # launched again.
      if node_ip_address is None:
        # Use the address 127.0.0.1 in local mode.
        node_ip_address = "127.0.0.1"
      if num_workers is None:
        # Use 10 workers if num_workers is not provided.
        num_workers = 10
      # If num_local_schedulers is not provided, use the number of existing
      # local schedulers when there are any, and 1 otherwise.
      existing_schedulers = address_info.get("local_scheduler_socket_names", [])
      if num_local_schedulers is None:
        num_local_schedulers = len(existing_schedulers) or 1
      # The started processes are killed by the call to cleanup(), which
      # happens when the Python script exits.
      address_info = services.start_ray_local(
          address_info=address_info,
          node_ip_address=node_ip_address,
          num_workers=num_workers,
          num_local_schedulers=num_local_schedulers)
    else:
      # Attaching to an existing cluster: validate the argument combination.
      if redis_address is None:
        raise Exception("If start_ray_local=False, then redis_address must be provided.")
      if num_workers is not None:
        raise Exception("If start_ray_local=False, then num_workers must not be provided.")
      if num_local_schedulers is not None:
        raise Exception("If start_ray_local=False, then num_local_schedulers must not be provided.")
      # Get the node IP address if one is not provided.
      if node_ip_address is None:
        node_ip_address = services.get_node_ip_address(redis_address)
      # Get the address info of the processes to connect to from Redis.
      address_info = get_address_info_from_redis(redis_address, node_ip_address)

  # Connect this driver to Redis, the object store, and the local scheduler.
  # The first object store and local scheduler are chosen if there are
  # multiple. The corresponding call to disconnect happens in the call to
  # cleanup() when the Python script exits.
  driver_address_info = {}
  if driver_mode != PYTHON_MODE:
    redis_address_to_use = address_info["redis_address"]
    first_object_store = address_info["object_store_addresses"][0]
    driver_address_info = {
        "node_ip_address": node_ip_address,
        "redis_address": redis_address_to_use,
        "store_socket_name": first_object_store.name,
        "manager_socket_name": first_object_store.manager_name,
        "local_scheduler_socket_name": address_info["local_scheduler_socket_names"][0],
    }
  connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker)
  return address_info
|
|
|
|
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
         num_workers=None, driver_mode=SCRIPT_MODE):
  """Either connect to an existing Ray cluster or start one and connect to it.

  This method handles two cases. Either a Ray cluster already exists and we
  just attach this driver to it, or we start all of the processes associated
  with a Ray cluster and attach to the newly started cluster.

  Args:
    redis_address (str): The address of the Redis server to connect to. If this
      address is not provided, then this command will start Redis, a global
      scheduler, a local scheduler, a plasma store, a plasma manager, and some
      workers. It will also kill these processes when Python exits.
    node_ip_address (str): The IP address of the node that we are on.
    object_id_seed (int): Used to seed the deterministic generation of object
      IDs. The same value can be used across multiple runs of the same job in
      order to generate the object IDs in a consistent manner. However, the same
      ID should not be used for different jobs.
    num_workers (int): The number of workers to start. This is only provided if
      redis_address is not provided.
    driver_mode: The mode in which to start the driver. This should be one of
      ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.

  Returns:
    Address information about the started processes.

  Raises:
    Exception: An exception is raised if an inappropriate combination of
      arguments is passed in.
  """
  info = {
      "node_ip_address": node_ip_address,
      "redis_address": redis_address,
  }
  # A local cluster is started exactly when no Redis address was given.
  # object_id_seed must be forwarded explicitly; otherwise _init falls back to
  # its default of None and the caller's seed is silently ignored.
  return _init(address_info=info, start_ray_local=(redis_address is None),
               object_id_seed=object_id_seed, num_workers=num_workers,
               driver_mode=driver_mode)
|
|
|
|
def cleanup(worker=global_worker):
  """Disconnect the driver, and terminate any processes started in init.

  This is registered to run automatically when a Python process that uses Ray
  exits. It is ok to run this twice in a row. Note that the tests call
  services.cleanup() manually because they need to start and stop many
  clusters, while the import and the exit only happen once.

  Args:
    worker: The worker object to disconnect and reset.
  """
  disconnect(worker)
  # Put the worker back into its unconnected state.
  worker.set_mode(None)
  worker.driver_export_counter = worker.worker_import_counter = 0
  # The plasma client only exists if the worker got far enough in connect().
  if hasattr(worker, "plasma_client"):
    worker.plasma_client.shutdown()
  services.cleanup()
|
|
|
|
# Run cleanup() automatically when the Python interpreter exits.
atexit.register(cleanup)
|
|
|
|
def print_error_messages(worker):
  """Print error messages in the background on the driver.

  This runs in a separate thread on the driver and prints error messages in the
  background. It first prints every error already listed under the Redis key
  "ErrorKeys", and then prints newly appended errors each time a keyspace
  notification for "ErrorKeys" arrives. It returns when the pubsub listen call
  raises redis.ConnectionError.

  Args:
    worker: The worker object whose redis_client, redis_address, and lock are
      used. Its error_message_pubsub_client attribute is set by this function.
  """
  # TODO(rkn): All error messages should have a "component" field indicating
  # which process the error came from (e.g., a worker or a plasma store).
  # Currently all error messages come from workers.

  # NOTE(review): the whitespace inside this string literal is part of the
  # printed output; keep it exactly as it is.
  helpful_message = """
  You can inspect errors by running

      ray.error_info()

  If this driver is hanging, start a new one with

      ray.init(redis_address="{}")
  """.format(worker.redis_address)

  worker.error_message_pubsub_client = worker.redis_client.pubsub()
  # Errors that are pushed after the call to
  # error_message_pubsub_client.psubscribe and before the call to
  # error_message_pubsub_client.listen will still be processed in the loop.
  worker.error_message_pubsub_client.psubscribe("__keyspace@0__:ErrorKeys")
  # Number of entries of the "ErrorKeys" list that have been printed so far.
  # It is used as the start index for later reads so nothing is printed twice.
  num_errors_printed = 0

  # Get the errors that occurred before the call to psubscribe.
  with worker.lock:
    error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
    for error_key in error_keys:
      error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
      print(error_message)
      print(helpful_message)
      num_errors_printed += 1

  try:
    # Every pubsub message triggers a read of the entries that have not been
    # printed yet; the content of the message itself is not inspected.
    for msg in worker.error_message_pubsub_client.listen():
      with worker.lock:
        for error_key in worker.redis_client.lrange("ErrorKeys", num_errors_printed, -1):
          error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
          print(error_message)
          print(helpful_message)
          num_errors_printed += 1
  except redis.ConnectionError:
    # When Redis terminates the listen call will throw a ConnectionError, which
    # we catch here.
    pass
|
|
|
|
def fetch_and_register_remote_function(key, worker=global_worker):
  """Import a remote function.

  The function's metadata and pickled body are read from the Redis hash stored
  at key. The metadata is recorded on the worker, a placeholder that raises is
  registered, and then the function is unpickled. On success the placeholder
  is replaced by the real function and this worker is appended to the
  function's "FunctionTable:" list in Redis. On failure the traceback is
  stored in Redis and its key is appended to "ErrorKeys".

  Args:
    key: The Redis key of the hash describing the exported remote function.
    worker: The worker object on which the function is registered.
  """
  field_names = ["function_id", "name", "function", "num_return_vals", "module", "function_export_counter"]
  (function_id_str, raw_function_name, serialized_function,
   raw_num_return_vals, raw_module,
   raw_export_counter) = worker.redis_client.hmget(key, field_names)
  function_id = photon.ObjectID(function_id_str)
  function_name = raw_function_name.decode("ascii")
  num_return_vals = int(raw_num_return_vals)
  module = raw_module.decode("ascii")
  function_export_counter = int(raw_export_counter)

  # Record the metadata before attempting to unpickle the function.
  worker.function_names[function_id.id()] = function_name
  worker.num_return_vals[function_id.id()] = num_return_vals
  worker.function_export_counters[function_id.id()] = function_export_counter
  # Register a placeholder in case the function can't be unpickled. It is
  # overwritten below if the function is unpickled successfully.
  def f():
    raise Exception("This function was not imported properly.")
  placeholder = remote(num_return_vals=num_return_vals, function_id=function_id)(lambda *xs: f())
  worker.functions[function_id.id()] = placeholder

  try:
    function = pickling.loads(serialized_function)
  except:
    # The import failed, so record the traceback in Redis where the driver
    # can find it.
    traceback_str = format_error_message(traceback.format_exc())
    error_key = "RemoteFunctionImportError:{}".format(function_id.id())
    error_fields = {"function_id": function_id.id(),
                    "function_name": function_name,
                    "message": traceback_str}
    worker.redis_client.hmset(error_key, error_fields)
    worker.redis_client.rpush("ErrorKeys", error_key)
  else:
    # TODO(rkn): Why is the below line necessary?
    function.__module__ = module
    registered = remote(num_return_vals=num_return_vals, function_id=function_id)(function)
    worker.functions[function_id.id()] = registered
    # Add the function to the function table.
    worker.redis_client.rpush("FunctionTable:{}".format(function_id.id()), worker.worker_id)
|
|
|
|
def fetch_and_register_environment_variable(key, worker=global_worker):
  """Import an environment variable.

  The name and the pickled initializer and reinitializer are read from the
  Redis hash stored at key, unpickled, and set on env as an
  EnvironmentVariable. If anything in that process raises, the traceback is
  stored in Redis and its key is appended to "ErrorKeys".

  Args:
    key: The Redis key of the hash describing the exported variable.
    worker: The worker object whose redis_client is used.
  """
  field_names = ["name", "initializer", "reinitializer"]
  raw_name, serialized_initializer, serialized_reinitializer = worker.redis_client.hmget(key, field_names)
  environment_variable_name = raw_name.decode("ascii")
  try:
    initializer = pickling.loads(serialized_initializer)
    reinitializer = pickling.loads(serialized_reinitializer)
    variable = EnvironmentVariable(initializer, reinitializer)
    env.__setattr__(environment_variable_name, variable)
  except:
    # The import failed, so record the traceback in Redis where the driver
    # can find it.
    traceback_str = format_error_message(traceback.format_exc())
    error_key = "EnvironmentVariableImportError:{}".format(random_string())
    error_fields = {"name": environment_variable_name,
                    "message": traceback_str}
    worker.redis_client.hmset(error_key, error_fields)
    worker.redis_client.rpush("ErrorKeys", error_key)
|
|
|
|
def fetch_and_execute_function_to_run(key, worker=global_worker):
  """Run an arbitrary function on the worker.

  The pickled function is read from the Redis hash stored at key, unpickled,
  and called with a dictionary holding a "counter" entry. If unpickling or
  running the function raises, the traceback is stored in Redis and its key is
  appended to "ErrorKeys".

  Args:
    key: The Redis key of the hash holding the pickled function.
    worker: The worker object whose redis_client and node_ip_address are used.
  """
  # Stays None if unpickling fails, which yields an empty name below.
  function = None
  [serialized_function] = worker.redis_client.hmget(key, ["function"])
  # Get the number of workers on this node that have already started executing
  # this function, and increment that value. Subtract 1 so the counter starts
  # at 0.
  counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1
  try:
    # Deserialize the function and run it.
    function = pickling.loads(serialized_function)
    function({"counter": counter})
  except:
    # Something went wrong, so record the traceback in Redis where the driver
    # can find it.
    traceback_str = traceback.format_exc()
    name = function.__name__ if hasattr(function, "__name__") else ""
    error_key = "FunctionToRunError:{}".format(random_string())
    error_fields = {"name": name,
                    "message": traceback_str}
    worker.redis_client.hmset(error_key, error_fields)
    worker.redis_client.rpush("ErrorKeys", error_key)
|
|
|
|
def import_thread(worker):
  """Import exports from Redis on this worker as they are published.

  This subscribes to keyspace notifications for the Redis list "Exports",
  processes every key already in that list, and then processes newly appended
  keys as notifications arrive. Each key is dispatched on its prefix to
  register a remote function, register an environment variable, or run a
  function. After each key is handled, the "export_counter" field of this
  worker's "WorkerInfo:" hash in Redis and worker.worker_import_counter are
  both incremented.

  Args:
    worker: The worker object whose redis_client, lock, and
      worker_import_counter are used. Its import_pubsub_client attribute is
      set by this function.

  Raises:
    Exception: If an export key has an unrecognized prefix.
  """
  worker.import_pubsub_client = worker.redis_client.pubsub()
  # Exports that are published after the call to import_pubsub_client.psubscribe
  # and before the call to import_pubsub_client.listen will still be processed
  # in the loop.
  worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports")
  worker_info_key = "WorkerInfo:{}".format(worker.worker_id)
  worker.redis_client.hset(worker_info_key, "export_counter", 0)
  worker.worker_import_counter = 0

  # Get the exports that occurred before the call to psubscribe.
  with worker.lock:
    export_keys = worker.redis_client.lrange("Exports", 0, -1)
    for key in export_keys:
      if key.startswith(b"RemoteFunction"):
        fetch_and_register_remote_function(key, worker=worker)
      elif key.startswith(b"EnvironmentVariables"):
        fetch_and_register_environment_variable(key, worker=worker)
      elif key.startswith(b"FunctionsToRun"):
        fetch_and_execute_function_to_run(key, worker=worker)
      else:
        raise Exception("This code should be unreachable.")
      # Keep the counter in Redis and the local counter in step.
      worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
      worker.worker_import_counter += 1

  for msg in worker.import_pubsub_client.listen():
    with worker.lock:
      # Skip the confirmation message for the subscription itself.
      if msg["type"] == "psubscribe":
        continue
      assert msg["data"] == b"rpush"
      num_imports = worker.redis_client.llen("Exports")
      assert num_imports >= worker.worker_import_counter
      # Import every export that has been appended since the last import.
      for i in range(worker.worker_import_counter, num_imports):
        key = worker.redis_client.lindex("Exports", i)
        if key.startswith(b"RemoteFunction"):
          with log_span("ray:import_remote_function", worker=worker):
            fetch_and_register_remote_function(key, worker=worker)
        elif key.startswith(b"EnvironmentVariables"):
          with log_span("ray:import_environment_variable", worker=worker):
            fetch_and_register_environment_variable(key, worker=worker)
        elif key.startswith(b"FunctionsToRun"):
          with log_span("ray:import_function_to_run", worker=worker):
            fetch_and_execute_function_to_run(key, worker=worker)
        else:
          raise Exception("This code should be unreachable.")
        worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
        worker.worker_import_counter += 1
|
|
|
|
def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
  """Connect this worker to the local scheduler, to Plasma, and to Redis.

  Args:
    info (dict): A dictionary with address of the Redis server and the sockets
      of the plasma store, plasma manager, and local scheduler. It is not read
      when mode is PYTHON_MODE.
    object_id_seed: If not None, this is used to seed numpy's random number
      generator when the current task ID of a driver is generated. The
      previous state of the generator is restored afterwards.
    mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
      and SILENT_MODE.
    worker: The worker object to connect.

  Raises:
    Exception: If mode is not one of the modes listed above.
  """
  check_main_thread()
  # Do some basic checking to make sure we didn't call ray.init twice.
  error_message = "Perhaps you called ray.init twice by accident?"
  assert not worker.connected, error_message
  assert worker.cached_functions_to_run is not None, error_message
  assert worker.cached_remote_functions is not None, error_message
  assert env._cached_environment_variables is not None, error_message
  # Initialize some fields.
  worker.worker_id = random_string()
  worker.connected = True
  worker.set_mode(mode)
  # The worker.events field is used to aggregate logging information and display
  # it in the web UI. Note that Python lists are protected by the GIL, which is
  # important because we will append to this field from multiple threads.
  worker.events = []
  # If running Ray in PYTHON_MODE, return before any clients are created or
  # any threads are started.
  if mode == PYTHON_MODE:
    return
  # Set the node IP address.
  worker.node_ip_address = info["node_ip_address"]
  worker.redis_address = info["redis_address"]
  # Create a Redis client.
  redis_host, redis_port = info["redis_address"].split(":")
  worker.redis_client = redis.StrictRedis(host=redis_host, port=int(redis_port))
  worker.lock = threading.Lock()
  # Create an object store client.
  worker.plasma_client = plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"])
  # Create the local scheduler client.
  worker.photon_client = photon.PhotonClient(info["local_scheduler_socket_name"])
  # Register the worker with Redis.
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    worker.redis_client.hmset(b"Drivers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address})
  elif mode == WORKER_MODE:
    worker.redis_client.hmset(b"Workers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address})
  else:
    raise Exception("This code should be unreachable.")
  # If this is a driver, set the current task ID and set the task index to 0.
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    # If the user provided an object_id_seed, then set the current task ID
    # deterministically based on that seed (without altering the state of the
    # user's random number generator). Otherwise, set the current task ID
    # randomly to avoid object ID collisions.
    numpy_state = np.random.get_state()
    if object_id_seed is not None:
      np.random.seed(object_id_seed)
    else:
      # Try to use true randomness.
      np.random.seed(None)
    worker.current_task_id = photon.ObjectID(np.random.bytes(20))
    # Reset the state of the numpy random number generator.
    np.random.set_state(numpy_state)
    # Set other fields needed for computing task IDs.
    worker.task_index = 0
    worker.put_index = 0
  # If this is a worker, then start a thread to import exports from the driver.
  if mode == WORKER_MODE:
    t = threading.Thread(target=import_thread, args=(worker,))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()
  # If this is a driver running in SCRIPT_MODE, start a thread to print error
  # messages asynchronously in the background. Ideally the scheduler would push
  # messages to the driver's worker service, but we ran into bugs when trying to
  # properly shutdown the driver's worker service, so we are temporarily using
  # this implementation which constantly queries the scheduler for new error
  # messages.
  if mode == SCRIPT_MODE:
    t = threading.Thread(target=print_error_messages, args=(worker,))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()
  # Initialize the serialization library. This registers some classes, and so
  # it must be run before we export all of the cached remote functions.
  initialize_numbuf()
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    # Add the directory containing the script that is running to the Python
    # paths of the workers. Also add the current directory. Note that this
    # assumes that the directory structures on the machines in the clusters are
    # the same.
    script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
    current_directory = os.path.abspath(os.path.curdir)
    worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory))
    worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory))
    # TODO(rkn): Here we first export functions to run, then environment
    # variables, then remote functions. The order matters. For example, one of
    # the functions to run may set the Python path, which is needed to import a
    # module used to define an environment variable, which in turn is used
    # inside a remote function. We may want to change the order to simply be the
    # order in which the exports were defined on the driver. In addition, we
    # will need to retain the ability to decide what the first few exports are
    # (mostly to set the Python path). Additionally, note that the first exports
    # to be defined on the driver will be the ones defined in separate modules
    # that are imported by the driver.
    # Export cached functions_to_run.
    for function in worker.cached_functions_to_run:
      worker.run_function_on_all_workers(function)
    # Export cached environment variables to the workers.
    for name, environment_variable in env._cached_environment_variables:
      env.__setattr__(name, environment_variable)
    # Export cached remote functions to the workers.
    for function_id, func_name, func, num_return_vals in worker.cached_remote_functions:
      export_remote_function(function_id, func_name, func, num_return_vals, worker)
  # Clear the caches. The assertions at the top of this function rely on these
  # being None to detect a second call to connect without a disconnect.
  worker.cached_functions_to_run = None
  worker.cached_remote_functions = None
  env._cached_environment_variables = None
|
|
|
|
def disconnect(worker=global_worker):
  """Mark this worker as disconnected and reset its cached exports.

  Args:
    worker: The worker object to reset.
  """
  worker.connected = False
  # Start over with empty caches so that remote functions defined from now on
  # are exported when connect is called again. This is mostly relevant for the
  # tests.
  worker.cached_functions_to_run, worker.cached_remote_functions = [], []
  env._cached_environment_variables = []
|
|
|
|
def register_class(cls, pickle=False, worker=global_worker):
  """Enable workers to serialize or deserialize objects of a particular class.

  A small registration function is run on every worker; it adds cls to the
  serialization whitelist so that numbuf can properly serialize and
  deserialize objects of this class.

  Args:
    cls (type): The class that numbuf should serialize.
    pickle (bool): If False then objects of this class will be serialized by
      turning their __dict__ fields into a dictionary. If True, then objects
      of this class will be serialized using pickle.
    worker: The worker object used to run the registration on all workers.

  Raises:
    Exception: An exception is raised if pickle=False and the class cannot be
      efficiently serialized by Ray.
  """
  # Only drivers register classes. Returning early on workers lets Python
  # modules that register classes be imported on workers without any trouble.
  if worker.mode == WORKER_MODE:
    return

  if not pickle:
    # This raises if cls cannot be serialized efficiently by Ray.
    serialization.check_serializable(cls)

  def register_class_for_serialization(worker_info):
    serialization.add_class_to_whitelist(cls, pickle=pickle)

  worker.run_function_on_all_workers(register_class_for_serialization)
|
|
|
|
class RayLogSpan(object):
  """A context manager that logs the start and the end of a span of events.

  Attributes:
    event_type (str): The type of the event being logged.
    contents: Additional information to log when the span starts.
    worker: The worker object passed to log for every entry.
  """
  def __init__(self, event_type, contents=None, worker=global_worker):
    """Remember what to log and which worker to log it on."""
    self.event_type, self.contents, self.worker = event_type, contents, worker

  def __enter__(self):
    """Log the beginning of a span event."""
    log(event_type=self.event_type, contents=self.contents,
        kind=LOG_SPAN_START, worker=self.worker)

  def __exit__(self, type, value, tb):
    """Log the end of a span event. Log any exception that occurred."""
    end_contents = None
    if type is not None:
      # The block raised, so describe the exception in the end-of-span entry.
      end_contents = {"type": str(type),
                      "value": value,
                      "traceback": traceback.format_exc()}
    log(event_type=self.event_type, contents=end_contents,
        kind=LOG_SPAN_END, worker=self.worker)
|
|
|
|
def log_span(event_type, contents=None, worker=global_worker):
  """Return a RayLogSpan for use in a with statement.

  Args:
    event_type (str): The type of the event.
    contents: Additional information to log.
    worker: The worker object to log on.

  Returns:
    A RayLogSpan built from the arguments.
  """
  span = RayLogSpan(event_type, contents=contents, worker=worker)
  return span
|
|
|
|
def log_event(event_type, contents=None, worker=global_worker):
  """Log an event that happens at a single point in time.

  Args:
    event_type (str): The type of the event.
    contents: Additional information to log.
    worker: The worker object to log on.
  """
  log(event_type=event_type, contents=contents, kind=LOG_POINT, worker=worker)
|
|
|
|
def log(event_type, kind, contents=None, worker=global_worker):
  """Log an event to the global state store.

  The event is appended to a local buffer of events on the worker. The buffer
  can be flushed and written to the global state store by calling flush_log().

  Args:
    event_type (str): The type of the event.
    kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This is
      LOG_POINT if the event being logged happens at a single point in time. It
      is LOG_SPAN_START if we are starting to log a span of time, and it is
      LOG_SPAN_END if we are finishing logging a span of time.
    contents: More general data to store with the event. This must be a
      dictionary or None.
    worker: The worker object whose events buffer is appended to.
  """
  # TODO(rkn): This code currently takes around half a microsecond. Since we
  # call it tens of times per task, this adds up. We will need to redo the
  # logging code, perhaps in C.
  if contents is None:
    contents = {}
  assert isinstance(contents, dict)
  # Store every key and value of the dictionary as a string.
  stringified = {str(key): str(value) for key, value in contents.items()}
  event = (time.time(), event_type, kind, stringified)
  worker.events.append(event)
|
|
|
|
def flush_log(worker=global_worker):
  """Send the logged worker events to the global state store.

  The buffered events are serialized to JSON and handed to the local scheduler
  client under a key built from the worker ID and the current task ID. The
  buffer is emptied afterwards.

  Args:
    worker: The worker object whose events are flushed.
  """
  key_parts = [b"event_log", worker.worker_id, worker.current_task_id.id()]
  event_log_key = b":".join(key_parts)
  serialized_events = json.dumps(worker.events)
  worker.photon_client.log_event(event_log_key, serialized_events)
  # Start a fresh buffer for the events that follow.
  worker.events = []
|
|
|
|
def get(object_ids, worker=global_worker):
  """Get a remote object or a list of remote objects from the object store.

  This method blocks until the object corresponding to the object ID is
  available in the local object store. If this object is not in the local
  object store, it will be shipped from an object store that has it (once the
  object has been created). If object_ids is a list, then the objects
  corresponding to each object in the list will be returned.

  Args:
    object_ids: Object ID of the object to get or a list of object IDs to get.
    worker: The worker object used to fetch the objects.

  Returns:
    A Python object or a list of Python objects.

  Raises:
    RayGetError: If a fetched value is a RayTaskError, which means that the
      task that created the object failed.
  """
  check_connected(worker)
  with log_span("ray:get", worker=worker):
    check_main_thread()

    if worker.mode == PYTHON_MODE:
      # In PYTHON_MODE, ray.get is the identity operation (the input will
      # actually be a value and not an object ID).
      return object_ids

    if not isinstance(object_ids, list):
      # A single object ID was passed in.
      result = worker.get_object([object_ids])[0]
      if isinstance(result, RayTaskError):
        # The task that created this object failed, so propagate the error
        # message here.
        raise RayGetError(object_ids, result)
      return result

    results = worker.get_object(object_ids)
    for index, result in enumerate(results):
      if isinstance(result, RayTaskError):
        raise RayGetError(object_ids[index], result)
    return results
|
|
|
|
def put(value, worker=global_worker):
  """Store an object in the object store.

  Args:
    value (serializable object): The Python object to be stored.
    worker: The worker object used to store the value.

  Returns:
    The object ID assigned to this value. In PYTHON_MODE the value itself is
    returned.
  """
  check_connected(worker)
  with log_span("ray:put", worker=worker):
    check_main_thread()

    # In PYTHON_MODE, ray.put is the identity operation.
    if worker.mode == PYTHON_MODE:
      return value

    # The ID is computed from the current task ID and the number of puts
    # counted so far by worker.put_index.
    new_object_id = photon.compute_put_id(worker.current_task_id,
                                          worker.put_index)
    worker.put_object(new_object_id, value)
    worker.put_index += 1
    return new_object_id
|
|
|
|
def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
  """Return a list of IDs that are ready and a list of IDs that are not ready.

  If timeout is set, the function returns either when the requested number of
  IDs are ready or when the timeout is reached, whichever occurs first. If it
  is not set, a timeout of 2 ** 30 is passed to the plasma client.

  This method returns two lists. The first list consists of object IDs that
  correspond to objects that are stored in the object store. The second list
  corresponds to the rest of the object IDs (which may or may not be ready).

  Args:
    object_ids (List[ObjectID]): List of object IDs for objects that may
      or may not be ready.
    num_returns (int): The number of object IDs that should be returned.
    timeout (int): The maximum amount of time in milliseconds to wait before
      returning.
    worker: The worker object whose plasma client is used.

  Returns:
    A list of object IDs that are ready and a list of the remaining object IDs.
  """
  check_connected(worker)
  with log_span("ray:wait", worker=worker):
    check_main_thread()
    # The plasma client works with the raw IDs and not with ObjectID objects.
    raw_ids = [object_id.id() for object_id in object_ids]
    if timeout is None:
      timeout = 2 ** 30
    ready_raw, remaining_raw = worker.plasma_client.wait(raw_ids, timeout, num_returns)
    ready_ids = [photon.ObjectID(raw_id) for raw_id in ready_raw]
    remaining_ids = [photon.ObjectID(raw_id) for raw_id in remaining_raw]
    return ready_ids, remaining_ids
|
|
|
|
def wait_for_valid_import_counter(function_id, timeout=5, worker=global_worker):
  """Wait until this worker has imported enough to execute the function.

  This method will simply loop until the import thread has imported enough of
  the exports to execute the function. If we spend too long in this loop, that
  may indicate a problem somewhere and we will push an error message to the
  user. The warning is pushed at most once per call, and the loop keeps
  waiting afterwards.

  Args:
    function_id: The ID of the function that we want to execute. Its id()
      method provides the key used to look the function up on the worker.
    timeout: The number of seconds to wait before the warning is pushed.
    worker: The worker object whose registered functions and import counter
      are inspected.
  """
  start_time = time.time()
  # Only send the warning once.
  warning_sent = False
  while True:
    # The lock is held while the worker's import state is read, and released
    # before sleeping.
    with worker.lock:
      function_key = function_id.id()
      is_registered = function_key in worker.functions
      if is_registered and (worker.function_export_counters[function_key] <= worker.worker_import_counter):
        break
      # Build the warning message only when it is actually going to be sent.
      if not warning_sent and time.time() - start_time > timeout:
        if not is_registered:
          warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
        else:
          warning_message = "This worker's import counter is too small."
        push_warning_to_user(warning_message, worker=worker)
        warning_sent = True
    time.sleep(0.001)
|
|
|
|
def format_error_message(exception_message, task_exception=False):
  """Tidy up a traceback string produced for a failed remote function.

  Args:
    exception_message (str): A message generated by traceback.format_exc().
    task_exception (bool): If True, the second through fifth lines of the
      message are dropped. For errors raised inside a task these lines are
      always the same and only describe the worker's main loop.

  Returns:
    A string of the formatted exception message.
  """
  traceback_lines = exception_message.split("\n")
  if task_exception:
    # Keep the first line and everything from the sixth line onwards.
    del traceback_lines[1:5]
  return "\n".join(traceback_lines)
|
|
|
|
def main_loop(worker=global_worker):
  """The main loop a worker runs to receive and execute tasks.

  This method is an infinite loop. Each iteration fetches a task from
  worker.photon_client, waits until the function the task refers to has been
  imported, and then executes the task while holding worker.lock. Errors that
  occur while processing a task are stored in the object store as RayTaskError
  objects and are also recorded in Redis under the "ErrorKeys" list. The
  buffered log events are flushed after every task.

  Args:
    worker: The worker that fetches and executes the tasks.
  """

  # Wrapping these lines in a function should cause the local variables to go
  # out of scope more quickly, which is useful for inspecting reference counts.
  def process_task(task):
    """Execute a task assigned to this worker.

    This method reads the function ID, arguments and return object IDs from
    the task and attempts to execute it. If the task succeeds, the outputs are
    stored in the local object store. If the task throws an exception,
    RayTaskError objects are stored in the object store to represent the
    failed task (these will be retrieved by calls to get or by subsequent
    tasks that use the outputs of this task). After the task executes, the
    worker reinitializes the environment variables.

    Args:
      task: The task object returned by worker.photon_client.get_task().
    """
    # Explicit progress markers. They are set immediately after the
    # corresponding value has been produced so that the error handler below
    # can tell in which phase a failure happened.
    arguments_fetched = False
    task_executed = False
    try:
      worker.current_task_id = task.task_id()
      worker.task_index = 0
      worker.put_index = 0
      function_id = task.function_id()
      args = task.arguments()
      return_object_ids = task.returns()
      function_name = worker.function_names[function_id.id()]

      # Get task arguments from the object store.
      with log_span("ray:task:get_arguments", worker=worker):
        arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker)
        arguments_fetched = True

      # Execute the task.
      with log_span("ray:task:execute", worker=worker):
        outputs = worker.functions[function_id.id()].executor(arguments)
        task_executed = True

      # Store the outputs in the local object store.
      with log_span("ray:task:store_outputs", worker=worker):
        if len(return_object_ids) == 1:
          outputs = (outputs,)
        store_outputs_in_objstore(return_object_ids, outputs, worker)
    except Exception as e:
      # We determine whether the exception was caused by the call to
      # get_arguments_for_execution or by the execution of the remote function
      # or by the call to store_outputs_in_objstore. Depending on which case
      # occurred, we format the error message differently. We do this by
      # checking how far the task had progressed when the exception was
      # raised.
      if task_executed:
        # The error occurred after the task executed.
        traceback_str = format_error_message(traceback.format_exc())
      elif arguments_fetched:
        # The error occurred during the task execution.
        traceback_str = format_error_message(traceback.format_exc(), task_exception=True)
      else:
        # The error occurred before the task execution.
        traceback_str = None
      failure_object = RayTaskError(function_name, e, traceback_str)
      # Every return object ID of the task receives the same failure object.
      failure_objects = [failure_object for _ in range(len(return_object_ids))]
      store_outputs_in_objstore(return_object_ids, failure_objects, worker)
      # Log the error message.
      error_key = "TaskError:{}".format(random_string())
      worker.redis_client.hmset(error_key, {"function_id": function_id.id(),
                                            "function_name": function_name,
                                            "message": str(failure_object)})
      worker.redis_client.rpush("ErrorKeys", error_key)
    try:
      # Reinitialize the values of environment variables that were used in the
      # task above so that changes made to their state do not affect other
      # tasks.
      with log_span("ray:task:reinitialize_environment_variables", worker=worker):
        env._reinitialize()
    except Exception as e:
      # The attempt to reinitialize the environment variables threw an
      # exception. We record the traceback in Redis.
      traceback_str = format_error_message(traceback.format_exc())
      error_key = "EnvironmentVariableReinitializeError:{}".format(random_string())
      worker.redis_client.hmset(error_key, {"task_id": "NOTIMPLEMENTED",
                                            "function_id": function_id.id(),
                                            "function_name": function_name,
                                            "message": traceback_str})
      worker.redis_client.rpush("ErrorKeys", error_key)

  check_main_thread()
  while True:
    with log_span("ray:get_task", worker=worker):
      task = worker.photon_client.get_task()

    function_id = task.function_id()
    # Check that the number of imports we have is at least as great as the
    # export counter for the task. If not, wait until we have imported enough.
    # We will push warnings to the user if we spend too long in this loop.
    with log_span("ray:wait_for_import_counter", worker=worker):
      wait_for_valid_import_counter(function_id, worker=worker)

    # Execute the task.
    # TODO(rkn): Consider acquiring this lock with a timeout and pushing a
    # warning to the user if we are waiting too long to acquire the lock because
    # that may indicate that the system is hanging, and it'd be good to know
    # where the system is hanging.
    log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker)
    with worker.lock:
      log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker)

      contents = {"function_name": worker.function_names[function_id.id()],
                  "task_id": task.task_id().hex()}
      with log_span("ray:task", contents=contents, worker=worker):
        process_task(task)

    # Push all of the log events to the global state store.
    flush_log()
|
|
|
|
def push_warning_to_user(message, worker=global_worker):
  """Record a warning message in Redis.

  The message is stored in a hash under a fresh "GenericWarning:" key, and
  that key is appended to the "ErrorKeys" list.

  Args:
    message: The warning text to store.
    worker: The worker whose Redis client is used.
  """
  warning_key = "GenericWarning:{}".format(random_string())
  redis_client = worker.redis_client
  redis_client.hmset(warning_key, {"message": message})
  redis_client.rpush("ErrorKeys", warning_key)
|
|
|
|
def _submit_task(function_id, func_name, args, worker=global_worker):
  """Forward a task submission to worker.submit_task.

  The remote decorator calls this module-level function rather than
  worker.submit_task directly so that serializing a remote function does not
  try to serialize the worker object, which cannot be serialized.

  Args:
    function_id: The ID of the remote function to run.
    func_name: The name of the remote function.
    args: The arguments to pass to the remote function.
    worker: The worker that submits the task.

  Returns:
    Whatever worker.submit_task returns.
  """
  submitted = worker.submit_task(function_id, func_name, args)
  return submitted
|
|
|
|
def _mode(worker=global_worker):
  """Return the mode of the given worker.

  The remote decorator calls _mode() rather than reading worker.mode directly
  so that serializing a remote function does not try to serialize the worker
  object, which cannot be serialized.

  Args:
    worker: The worker whose mode is returned.

  Returns:
    The value of worker.mode.
  """
  current_mode = worker.mode
  return current_mode
|
|
|
|
def _env():
  """Return the module-level env object.

  Functions that need the env object go through this wrapper so that they can
  be pickled.

  Returns:
    The module-level env object.
  """
  environment = env
  return environment
|
|
|
|
def _export_environment_variable(name, environment_variable, worker=global_worker):
  """Publish an environment variable through Redis so workers can import it.

  The pickled initializer and reinitializer are stored in a Redis hash, the
  hash key is appended to the "Exports" list, and the driver's export counter
  is incremented. This may only be called on a driver (SCRIPT_MODE or
  SILENT_MODE).

  Args:
    name (str): The name of the variable to export.
    environment_variable (EnvironmentVariable): The environment variable object
      containing code for initializing and reinitializing the variable.
    worker: The driver that performs the export.

  Raises:
    Exception: An exception is raised if the worker is not a driver.
  """
  check_main_thread()
  is_driver = _mode(worker) in [SCRIPT_MODE, SILENT_MODE]
  if not is_driver:
    raise Exception("_export_environment_variable can only be called on a driver.")
  # The variable's name doubles as its ID.
  export_key = "EnvironmentVariables:{}".format(name)
  redis_client = worker.redis_client
  fields = {
      "name": name,
      "initializer": pickling.dumps(environment_variable.initializer),
      "reinitializer": pickling.dumps(environment_variable.reinitializer),
  }
  redis_client.hmset(export_key, fields)
  redis_client.rpush("Exports", export_key)
  worker.driver_export_counter += 1
|
|
|
|
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
  """Publish a remote function through Redis.

  The number of return values is recorded on the worker, the pickled function
  and its metadata are stored in a Redis hash, the hash key is appended to the
  "Exports" list, and the driver's export counter is incremented. This may
  only be called on a driver (SCRIPT_MODE or SILENT_MODE).

  Args:
    function_id: The ID of the function; its id() value is used in the Redis
      key and as the key into worker.num_return_vals.
    func_name: The name under which the function is exported.
    func: The Python function to pickle and export.
    num_return_vals: The number of values a call to the function returns.
    worker: The driver that performs the export.

  Raises:
    Exception: An exception is raised if the worker is not a driver.
  """
  check_main_thread()
  is_driver = _mode(worker) in [SCRIPT_MODE, SILENT_MODE]
  if not is_driver:
    raise Exception("export_remote_function can only be called on a driver.")
  raw_id = function_id.id()
  export_key = "RemoteFunction:{}".format(raw_id)
  worker.num_return_vals[raw_id] = num_return_vals
  pickled_func = pickling.dumps(func)
  redis_client = worker.redis_client
  # The export counter stored with the function is the counter value before
  # it is incremented for this export.
  fields = {
      "function_id": raw_id,
      "name": func_name,
      "module": func.__module__,
      "function": pickled_func,
      "num_return_vals": num_return_vals,
      "function_export_counter": worker.driver_export_counter,
  }
  redis_client.hmset(export_key, fields)
  redis_client.rpush("Exports", export_key)
  worker.driver_export_counter += 1
|
|
|
|
def remote(*args, **kwargs):
  """This decorator is used to create remote functions.

  It can be applied bare, as "@ray.remote", in which case the function returns
  one value, or with the single keyword argument num_return_vals, as
  "@ray.remote(num_return_vals=2)".

  Args:
    num_return_vals (int): The number of object IDs that a call to this function
      should return.
    function_id: Only honored when the worker is in WORKER_MODE. If given
      (together with num_return_vals), the decorated function uses this ID
      instead of one derived from its name.

  Returns:
    The wrapped function when used as a bare decorator, otherwise a decorator
    that produces the wrapped function. The wrapped function raises if it is
    called directly; it must be invoked through its "remote" attribute.
  """
  worker = global_worker

  def make_remote_decorator(num_return_vals, func_id=None):
    def remote_decorator(func):
      func_name = "{}.{}".format(func.__module__, func.__name__)
      if func_id is None:
        # Derive the ID from the qualified name: the first 20 bytes of its
        # SHA-256 digest.
        function_id = FunctionID((hashlib.sha256(func_name.encode("ascii")).digest())[:20])
      else:
        function_id = func_id

      def func_call(*args, **kwargs):
        """This gets run immediately when a worker calls a remote function."""
        check_connected()
        check_main_thread()
        args = list(args)
        # Fill in the remaining arguments from the keyword arguments, falling
        # back to the defaults from the function's signature.
        args.extend([kwargs[keyword] if keyword in kwargs else default
                     for keyword, default in keyword_defaults[len(args):]])
        if any([arg is funcsigs._empty for arg in args]):
          raise Exception("Not enough arguments were provided to {}.".format(func_name))
        if _mode() == PYTHON_MODE:
          # In PYTHON_MODE, remote calls simply execute the function. We copy the
          # arguments to prevent the function call from mutating them and to match
          # the usual behavior of immutable remote objects.
          try:
            _env()._running_remote_function_locally = True
            result = func(*copy.deepcopy(args))
          finally:
            _env()._reinitialize()
            _env()._running_remote_function_locally = False
          return result
        objectids = _submit_task(function_id, func_name, args)
        # A single object ID is returned unwrapped. If there are no object IDs
        # the call returns None.
        if len(objectids) == 1:
          return objectids[0]
        elif len(objectids) > 1:
          return objectids

      def func_executor(arguments):
        """This gets run when the remote function is executed."""
        return func(*arguments)

      def func_invoker(*args, **kwargs):
        """This is returned by the decorator and used to invoke the function."""
        raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))

      func_invoker.remote = func_call
      func_invoker.executor = func_executor
      func_invoker.is_remote = True
      func_invoker.func_name = func_name
      if sys.version_info >= (3, 0):
        func_invoker.__doc__ = func.__doc__
      else:
        func_invoker.func_doc = func.func_doc

      sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.items()]
      keyword_defaults = [(k, v.default) for k, v in sig_params]
      has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
      func_invoker.has_vararg_param = has_vararg_param
      has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
      check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)

      # Everything ready - export the function
      if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]:
        func_name_global_valid = func.__name__ in func.__globals__
        func_name_global_value = func.__globals__.get(func.__name__)
        # Set the function globally to make it refer to itself. This allows
        # the function to reference itself as a global variable.
        func.__globals__[func.__name__] = func_invoker
        try:
          # The pickled bytes are discarded here.
          # NOTE(review): presumably this call is kept so that pickling
          # problems surface at decoration time - confirm.
          pickling.dumps((func, num_return_vals, func.__module__))
        finally:
          # Undo our changes
          if func_name_global_valid:
            func.__globals__[func.__name__] = func_name_global_value
          else:
            del func.__globals__[func.__name__]
      if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
        export_remote_function(function_id, func_name, func, num_return_vals)
      elif worker.mode is None:
        # Not connected yet: remember the function in the worker's cache.
        worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals))
      return func_invoker

    return remote_decorator

  if _mode() == WORKER_MODE:
    if "function_id" in kwargs:
      num_return_vals = kwargs["num_return_vals"]
      function_id = kwargs["function_id"]
      return make_remote_decorator(num_return_vals, function_id)

  if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
    # This is the case where the decorator is just @ray.remote.
    num_return_vals = 1
    func = args[0]
    return make_remote_decorator(num_return_vals)(func)
  else:
    # This is the case where the decorator is something like
    # @ray.remote(num_return_vals=2).
    assert len(args) == 0 and "num_return_vals" in kwargs, "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
    num_return_vals = kwargs["num_return_vals"]
    assert "function_id" not in kwargs
    return make_remote_decorator(num_return_vals)
|
|
|
|
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
  """Check if we support the signature of this function.

  We currently do not allow remote functions to have **kwargs. We also do not
  support keyword arguments in conjunction with a *args argument.

  Args:
    has_kwargs_param (bool): True if the function being checked has a **kwargs
      argument.
    has_vararg_param (bool): True if the function being checked has a *args
      argument.
    keyword_defaults (List): A list of (name, default) pairs for the arguments
      to the function being checked. A default equal to funcsigs._empty means
      that the argument has no default value.
    name (str): The name of the function to check.

  Raises:
    Exception: An exception is raised if the signature is not supported.
  """
  # Check if the user specified kwargs. Note that a real exception object has
  # to be raised here; raising a bare string is itself an error.
  if has_kwargs_param:
    raise Exception("Function {} has a **kwargs argument, which is currently not supported.".format(name))
  # Check if the user specified a variable number of arguments and any keyword
  # arguments.
  if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
    raise Exception("Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name))
|
|
|
|
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
  """Resolve the arguments of a remote function before it is executed.

  Arguments that are ObjectIDs are replaced by the values fetched through
  worker.get_object. All other arguments were passed by value and are used
  unchanged. This is called by the worker that is executing the remote
  function.

  Args:
    function (Callable): The remote function whose arguments are being
      retrieved.
    serialized_args (List): The arguments to the function. Each is either an
      ObjectID or a value that was passed directly.
    worker: The worker used to fetch the objects.

  Returns:
    A list with the fetched values in place of the ObjectIDs and the
    by-value arguments left as they were.

  Raises:
    RayGetArgumentError: This exception is raised if a task that created one of
      the arguments failed.
  """
  resolved = []
  for position, serialized in enumerate(serialized_args):
    if not isinstance(serialized, photon.ObjectID):
      # Passed by value: use it as is.
      resolved.append(serialized)
      continue
    # Passed by reference: fetch the object from the object store.
    fetched = worker.get_object([serialized])[0]
    if isinstance(fetched, RayTaskError):
      # If the result is a RayTaskError, then the task that created this
      # object failed, and we should propagate the error message here.
      raise RayGetArgumentError(function.__name__, position, serialized, fetched)
    resolved.append(fetched)
  return resolved
|
|
|
|
def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
  """Put the return values of a remote function into the object store.

  Each output is stored through worker.put_object under the object ID at the
  same position. Returning an ObjectID from a remote function is not allowed:
  all outputs are checked first, and nothing is stored if any of them is an
  ObjectID. This is called by the worker that executes the remote function.

  Note:
    The arguments objectids and outputs should have the same length.

  Args:
    objectids (List[ObjectID]): The object IDs that were assigned to the
      outputs of the remote function call.
    outputs (Tuple): The value returned by the remote function. If the remote
      function was supposed to only return one value, then its output was
      wrapped in a tuple with one element prior to being passed into this
      function.
    worker: The worker used to store the objects.

  Raises:
    Exception: An exception is raised if one of the outputs is an ObjectID.
  """
  num_objects = len(objectids)
  # Validate every output before storing anything.
  for index in range(num_objects):
    output = outputs[index]
    if isinstance(output, photon.ObjectID):
      raise Exception("This remote function returned an ObjectID as its {}th return value. This is not allowed.".format(index))
  for index in range(num_objects):
    worker.put_object(objectids[index], outputs[index])
|