Files
ray/python/ray/worker.py
T
Johann Schleier-Smith 6ad2b5d87a Add Redis port option to startup script (#232)
* specify redis address when starting head

* cleanup

* update starting cluster documentation

* Whitespace.

* Address Philipp's comments.

* Change redis_host -> redis_ip_address.
2017-01-31 00:28:00 -08:00

1790 lines
80 KiB
Python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import hashlib
import os
import sys
import time
import traceback
import copy
import funcsigs
import numpy as np
import colorama
import atexit
import random
import redis
import threading
import string
# Ray modules
import ray.pickling as pickling
import ray.serialization as serialization
import ray.services as services
import numbuf
import photon
import plasma
# Worker modes (see Worker.set_mode): SCRIPT_MODE is a normal driver,
# WORKER_MODE is a non-driver worker, PYTHON_MODE is a driver that runs remote
# functions serially for debugging, SILENT_MODE is a driver used in tests that
# does not print error information.
SCRIPT_MODE = 0
WORKER_MODE = 1
PYTHON_MODE = 2
SILENT_MODE = 3
# Logging event kinds; presumably consumed by log_span, which is defined
# elsewhere in this file — TODO confirm.
LOG_POINT = 0
LOG_SPAN_START = 1
LOG_SPAN_END = 2
# Error keys stored in Redis have the layout
#   ERROR_KEY_PREFIX + <driver ID> + b":" + <error ID>
# They are built in Worker.push_error_to_driver and length-checked and
# parsed in error_applies_to_driver.
ERROR_KEY_PREFIX = b"Error:"
DRIVER_ID_LENGTH = 20
ERROR_ID_LENGTH = 20
def random_string():
    """Return 20 random bytes drawn from NumPy's global random state.

    Because NumPy's global generator is used, seeding ``np.random`` makes the
    sequence of returned strings reproducible.
    """
    num_bytes = 20
    return np.random.bytes(num_bytes)
def random_object_id():
    """Return a new photon.ObjectID built from a fresh random byte string."""
    raw_id = random_string()
    return photon.ObjectID(raw_id)
class FunctionID(object):
    """A thin wrapper around the raw identifier of a remote function.

    Attributes:
        function_id: The raw identifier passed to the constructor.
    """

    def __init__(self, function_id):
        """Remember the raw identifier of the function."""
        self.function_id = function_id

    def id(self):
        """Return the raw identifier given to the constructor."""
        raw = self.function_id
        return raw
# ObjectIDs encountered by the custom ObjectID serializer that is registered in
# initialize_numbuf. That serializer appends to this module-level list, and
# callers must reset it between serializations.
contained_objectids = []


def numbuf_serialize(value):
    """Serialize a value with numbuf while tracking the ObjectIDs inside it.

    The custom ObjectID serializer registered in initialize_numbuf closes over
    the module-level ``contained_objectids`` list and appends each ObjectID it
    serializes to that list. The list must therefore be empty on entry. It
    should be reset between calls to numbuf_serialize.

    Args:
        value: A Python object that will be serialized.

    Returns:
        The serialized object (``value`` wrapped in a one-element list).
    """
    assert not contained_objectids, "This should be unreachable."
    wrapped = [value]
    return numbuf.serialize_list(wrapped)
class RayTaskError(Exception):
    """An object used internally to represent a task that threw an exception.

    If a task throws an exception during execution, a RayTaskError is stored in
    the object store for each of the task's outputs. When an object is retrieved
    from the object store, the Python method that retrieved it checks to see if
    the object is a RayTaskError and if it is then an exception is thrown
    propagating the error message.

    Currently, we either use the exception attribute or the traceback attribute
    but not both.

    Attributes:
        function_name (str): The name of the function that failed and produced
            the RayTaskError.
        exception (Optional[Exception]): The exception thrown by the failed
            task. It is kept only if it is a RayGetError or a
            RayGetArgumentError; any other exception is stored as None.
        traceback_str (str): The traceback from the exception, or None if
            getting the task arguments failed.
    """

    def __init__(self, function_name, exception, traceback_str):
        """Initialize a RayTaskError."""
        if isinstance(exception, (RayGetError, RayGetArgumentError)):
            kept_exception = exception
        else:
            kept_exception = None
        # Forward the (filtered) constructor arguments to Exception so that
        # ``args`` is populated. pickle rebuilds exceptions as ``cls(*args)``,
        # which fails if ``args`` is empty and __init__ requires arguments.
        super(RayTaskError, self).__init__(function_name, kept_exception,
                                           traceback_str)
        self.function_name = function_name
        self.exception = kept_exception
        self.traceback_str = traceback_str

    def __str__(self):
        """Format a RayTaskError as a string."""
        if self.traceback_str is None:
            # This path is taken if getting the task arguments failed.
            details = self.exception
        else:
            # This path is taken if the task execution failed.
            details = self.traceback_str
        return "Remote function {}{}{} failed with:\n\n{}".format(colorama.Fore.RED, self.function_name, colorama.Fore.RESET, details)
class RayGetError(Exception):
    """An exception used when get is called on an output of a failed task.

    Attributes:
        objectid (lib.ObjectID): The ObjectID that get was called on.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, objectid, task_error):
        """Initialize a RayGetError object."""
        # Populate ``args`` so that pickle, which rebuilds exceptions as
        # ``cls(*args)``, can reconstruct this object.
        super(RayGetError, self).__init__(objectid, task_error)
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetError as a string."""
        return "Could not get objectid {}. It was created by remote function {}{}{} which failed with:\n\n{}".format(self.objectid, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
class RayGetArgumentError(Exception):
    """An exception used when a task's argument was produced by a failed task.

    Attributes:
        argument_index (int): The index (zero indexed) of the failed argument
            in the present task's remote function call.
        function_name (str): The name of the function for the current task.
        objectid (lib.ObjectID): The ObjectID that was passed in as the
            argument.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, function_name, argument_index, objectid, task_error):
        """Initialize a RayGetArgumentError object."""
        # Populate ``args`` in constructor order so that pickle, which rebuilds
        # exceptions as ``cls(*args)``, can reconstruct this object.
        super(RayGetArgumentError, self).__init__(function_name, argument_index,
                                                  objectid, task_error)
        self.argument_index = argument_index
        self.function_name = function_name
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetArgumentError as a string."""
        return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
class EnvironmentVariable(object):
    """A Python object that can be shared between tasks.

    Attributes:
        initializer (Callable[[], object]): Builds a fresh value of the
            environment variable.
        reinitializer (Callable[[object], object]): Takes the current value and
            returns the value to use for the next task. When the caller does not
            provide one, this defaults to ignoring the current value and calling
            the initializer again. A custom reinitializer can be an optimization
            if there is a faster way to restore the variable's state than
            rerunning the initializer.
    """

    def __init__(self, initializer, reinitializer=None):
        """Validate and store the initializer and reinitializer callables.

        Raises:
            Exception: If initializer or reinitializer is not callable.
        """
        if not callable(initializer):
            raise Exception("When creating an EnvironmentVariable, initializer must be a function.")
        self.initializer = initializer
        if reinitializer is None:
            # Default: discard the old value and build a new one from scratch.
            def reinitializer(value):
                return initializer()
        if not callable(reinitializer):
            raise Exception("When creating an EnvironmentVariable, reinitializer must be a function.")
        self.reinitializer = reinitializer
class RayEnvironmentVariables(object):
    """An object used to store Python variables that are shared between tasks.

    Each worker process will have a single RayEnvironmentVariables object. This
    class serves two purposes. First, some objects are not serializable, and so
    the code that creates those objects must be run on the worker that uses
    them. This class is responsible for running the code that creates those
    objects. Second, some of these objects are expensive to create, and so they
    should be shared between tasks. However, if a task mutates a variable that
    is shared between tasks, then the behavior of the overall program may be
    nondeterministic (it could depend on scheduling decisions). To fix this, if
    a task uses one of these shared objects, then that shared object will be
    reinitialized after the task finishes. Since the initialization may be
    expensive, the user can pass in custom reinitialization code that resets
    the state of the shared variable to the way it was after initialization. If
    the reinitialization code does not do this, then the behavior of the
    overall program is undefined.

    Attributes:
        _names (Set[str]): The names of all the environment variables.
        _reinitializers (Dict[str, Callable]): A dictionary mapping the name of
            each environment variable to the corresponding reinitializer.
        _running_remote_function_locally (bool): A flag used to indicate if a
            remote function is running locally on the driver so that we can
            simulate the same behavior as running a remote function remotely.
        _environment_variables: A dictionary mapping the name of an environment
            variable to the value of the environment variable.
        _local_mode_environment_variables: A copy of _environment_variables
            used on the driver when running remote functions locally on the
            driver. This is needed because there are two ways in which
            environment variables can be used on the driver. The first is that
            the driver's copy can be manipulated. This copy is never reset
            (think of the driver as a single long-running task). The second way
            is that a remote function can be run locally on the driver, and
            this remote function needs access to a copy of the environment
            variable, and that copy must be reinitialized after use.
        _cached_environment_variables (List[Tuple[str, EnvironmentVariable]]):
            A list of (name, EnvironmentVariable) pairs for environment
            variables that are defined before the driver is connected. Once the
            driver is connected, these variables will be exported.
        _used (Set[str]): The names of the environment variables that have been
            accessed within the scope of the current task. This is cleared by
            _reinitialize after each task.
        _slots (Tuple[str]): Attribute names that __getattribute__ and
            __setattr__ treat as ordinary attributes rather than as environment
            variables.
    """

    def __init__(self):
        """Initialize a RayEnvironmentVariables object."""
        # Until _slots exists, __setattr__ falls back to object.__setattr__, so
        # these assignments create plain attributes.
        self._names = set()
        self._reinitializers = {}
        self._running_remote_function_locally = False
        self._environment_variables = {}
        self._local_mode_environment_variables = {}
        self._cached_environment_variables = []
        self._used = set()
        self._slots = ("_names", "_reinitializers", "_running_remote_function_locally", "_environment_variables", "_local_mode_environment_variables", "_cached_environment_variables", "_used", "_slots", "_create_environment_variable", "_reinitialize", "__getattribute__", "__setattr__", "__delattr__")
        # CHECKPOINT: Attributes must not be added after _slots. The above attributes are protected from deletion.

    def _create_environment_variable(self, name, environment_variable):
        """Create an environment variable locally.

        Args:
            name (str): The name of the environment variable.
            environment_variable (EnvironmentVariable): The environment
                variable object to use to create the environment variable.
        """
        self._names.add(name)
        self._reinitializers[name] = environment_variable.reinitializer
        self._environment_variables[name] = environment_variable.initializer()
        # We create a second copy of the environment variable on the driver to
        # use inside of remote functions that run locally. This occurs when we
        # start Ray in PYTHON_MODE and when we call a remote function locally.
        if _mode() in [SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
            self._local_mode_environment_variables[name] = environment_variable.initializer()

    def _reinitialize(self):
        """Reinitialize the environment variables that the current task used.

        On the driver, the task ran locally against the copies in
        _local_mode_environment_variables. Those copies are what must be
        reinitialized; the driver's own long-lived copies are left untouched.
        On a worker, the copies in _environment_variables are reinitialized.
        """
        for name in self._used:
            if _mode() in [SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
                assert self._running_remote_function_locally
                # Reinitialize the copy the local task actually used, not the
                # driver's own copy. Passing the driver's copy would mutate it
                # and could alias it into the local-mode dictionary.
                values = self._local_mode_environment_variables
            else:
                values = self._environment_variables
            values[name] = self._reinitializers[name](values[name])
        self._used.clear()  # Reset the set of used variables.

    def __getattribute__(self, name):
        """Get an attribute. This handles environment variables as a special case.

        When __getattribute__ is called with the name of an environment
        variable, that name is added to the set of variables that were used in
        the current task.

        Args:
            name (str): The name of the attribute to get.
        """
        # "_slots" is special-cased to avoid infinite recursion below.
        if name == "_slots":
            return object.__getattribute__(self, name)
        # Handle whitelisted fields and other fields that are not environment
        # variables.
        if name in self._slots or name not in self._names:
            return object.__getattribute__(self, name)
        # Make a note of the fact that the environment variable has been used.
        self._used.add(name)
        if self._running_remote_function_locally:
            return self._local_mode_environment_variables[name]
        return self._environment_variables[name]

    def __setattr__(self, name, value):
        """Set an attribute. This handles environment variables as a special case.

        This is used to create environment variables. When it is called, it
        runs the function for initializing the variable to create the variable.
        If this is called on the driver, then the functions for initializing and
        reinitializing the variable are shipped to the workers.

        If this is called before ray.init has been run, then the environment
        variable will be cached and it will be created and exported when
        connect is called.

        Args:
            name (str): The name of the attribute to set. This is either a
                whitelisted name or it is treated as the name of an environment
                variable.
            value: If name is a whitelisted name, then value can be any value.
                If name is the name of an environment variable, then this is an
                EnvironmentVariable object.

        Raises:
            Exception: If name is not whitelisted and value is not an
                EnvironmentVariable.
        """
        try:
            slots = self._slots
        except AttributeError:
            # _slots has not been set yet (we are still inside __init__).
            slots = ()
        if slots == ():
            return object.__setattr__(self, name, value)
        if name in slots:
            return object.__setattr__(self, name, value)
        environment_variable = value
        if not issubclass(type(environment_variable), EnvironmentVariable):
            raise Exception("To set an environment variable, you must pass in an EnvironmentVariable object")
        # If ray.init has not been called, cache the environment variable to
        # export later. Otherwise, export the environment variable to the
        # workers and define it locally.
        if _mode() is None:
            self._cached_environment_variables.append((name, environment_variable))
        else:
            # If we are on the driver, export the environment variable to all
            # the workers.
            if _mode() in [SCRIPT_MODE, SILENT_MODE]:
                _export_environment_variable(name, environment_variable)
            # Define the environment variable locally.
            self._create_environment_variable(name, environment_variable)
            # Create an empty attribute with the name of the environment
            # variable. This allows the Python interpreter to do tab complete
            # properly.
            object.__setattr__(self, name, None)

    def __delattr__(self, name):
        """We do not allow attributes of RayEnvironmentVariables to be deleted.

        Args:
            name (str): The name of the attribute to delete.

        Raises:
            Exception: Always.
        """
        raise Exception("Attempted deletion of attribute {}. Attributes of a RayEnvironmentVariables object may not be deleted.".format(name))
class Worker(object):
    """A class used to define the control flow of a worker process.

    Note:
        The methods in this class are considered unexposed to the user. The
        functions outside of this class are considered exposed.

    Attributes:
        functions (Dict[str, Callable]): A dictionary mapping the name of a
            remote function to the remote function itself. This is the set of
            remote functions that can be executed by this worker.
        connected (bool): True if Ray has been started and False otherwise.
        mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE,
            SILENT_MODE, and WORKER_MODE.
        cached_remote_functions (List[Tuple[str, str]]): A list of pairs
            representing the remote functions that were defined before the
            worker called connect. The first element is the name of the remote
            function, and the second element is the serialized remote
            function. When the worker eventually does call connect, if it is a
            driver, it will export these functions to the scheduler. If
            cached_remote_functions is None, that means that connect has been
            called already.
        cached_functions_to_run (List): A list of functions to run on all of
            the workers that should be exported as soon as connect is called.
        driver_export_counter (int): The number of exports that the driver has
            exported. This is only used on the driver.
        worker_import_counter (int): The number of exports that the worker has
            imported so far. This is only used on the workers.
    """

    def __init__(self):
        """Initialize a Worker object."""
        self.functions = {}
        self.num_return_vals = {}
        self.function_names = {}
        self.function_export_counters = {}
        self.connected = False
        self.mode = None
        self.cached_remote_functions = []
        self.cached_functions_to_run = []
        self.driver_export_counter = 0
        self.worker_import_counter = 0

    def set_mode(self, mode):
        """Set the mode of the worker.

        The mode SCRIPT_MODE should be used if this Worker is a driver that is
        being run as a Python script or interactively in a shell. It will print
        information about task failures.

        The mode WORKER_MODE should be used if this Worker is not a driver. It
        will not print information about tasks.

        The mode PYTHON_MODE should be used if this Worker is a driver and if
        you want to run the driver in a manner equivalent to serial Python for
        debugging purposes. It will not send remote function calls to the
        scheduler and will instead execute them in a blocking fashion.

        The mode SILENT_MODE should be used only during testing. It does not
        print any information about errors because some of the tests
        intentionally fail.

        Args:
            mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and SILENT_MODE.
        """
        self.mode = mode
        colorama.init()

    def put_object(self, objectid, value):
        """Put value in the local object store with object id objectid.

        This assumes that the value for objectid has not yet been placed in the
        local object store.

        Args:
            objectid (object_id.ObjectID): The object ID of the value to be put.
            value (serializable object): The value to put in the object store.
        """
        global contained_objectids
        try:
            # Serialize and put the object in the object store.
            numbuf.store_list(objectid.id(), self.plasma_client.conn, [value])
        except plasma.plasma_object_exists_error:
            # The object already exists in the object store, so there is no
            # need to add it again. TODO(rkn): We need to compare the hashes
            # and make sure that the objects are in fact the same. We also
            # should return an error code to the caller instead of printing a
            # message.
            print("This object already exists in the object store.")
        finally:
            # Optionally do something with the contained_objectids here. The
            # list is reset on every exit path, including the "already exists"
            # path. Otherwise IDs recorded by the ObjectID serializer during
            # this call would leak into the next serialization.
            contained_objectids = []

    def get_object(self, object_ids):
        """Get the value or values in the local object store associated with object_ids.

        Return the values from the local object store for object_ids. This will
        block until all the values for object_ids have been written to the
        local object store.

        Args:
            object_ids (List[object_id.ObjectID]): A list of the object IDs
                whose values should be retrieved.

        Returns:
            A list with the value of each object, in the order of object_ids.
            An empty input yields an empty list.
        """
        if not object_ids:
            # Nothing to retrieve. Without this check the loop below never runs
            # and ``results`` would be referenced before assignment.
            return []
        raw_ids = [object_id.id() for object_id in object_ids]
        self.plasma_client.fetch(raw_ids)
        # We currently pass in a timeout of one second (1000) per attempt, and
        # keep polling until every requested object is available.
        results = numbuf.retrieve_list(raw_ids, self.plasma_client.conn, 1000)
        while any(val is None for (_, val) in results):
            # This would be a natural place to issue a command to reconstruct
            # some of the objects.
            results = numbuf.retrieve_list(raw_ids, self.plasma_client.conn, 1000)
        # Unwrap the object from the list (it was wrapped in put_object).
        assert len(results) == len(object_ids)
        for raw_id, (result_id, _) in zip(raw_ids, results):
            assert result_id == raw_id
        return [value[0] for (_, value) in results]

    def submit_task(self, function_id, func_name, args):
        """Submit a remote task to the scheduler.

        Tell the scheduler to schedule the execution of the function with name
        func_name with arguments args. Retrieve object IDs for the outputs of
        the function from the scheduler and immediately return them.

        Args:
            function_id (FunctionID): The ID of the function to be executed.
            func_name (str): The name of the function to be executed.
            args (List[Any]): The arguments to pass into the function.
                Arguments can be object IDs or they can be values. If they are
                values, they must be serializable objects.

        Returns:
            The object IDs of the task's return values.
        """
        with log_span("ray:submit_task", worker=self):
            check_main_thread()
            # Put large or complex arguments that are passed by value in the
            # object store first.
            args_for_photon = []
            for arg in args:
                if isinstance(arg, photon.ObjectID):
                    args_for_photon.append(arg)
                elif photon.check_simple_value(arg):
                    args_for_photon.append(arg)
                else:
                    args_for_photon.append(put(arg))
            # Submit the task to Photon.
            task = photon.Task(self.task_driver_id,
                               photon.ObjectID(function_id.id()),
                               args_for_photon,
                               self.num_return_vals[function_id.id()],
                               self.current_task_id,
                               self.task_index)
            # Increment the worker's task index to track how many tasks have
            # been submitted by the current task so far.
            self.task_index += 1
            self.photon_client.submit(task)
            return task.returns()

    def run_function_on_all_workers(self, function):
        """Run arbitrary code on all of the workers.

        This function will first be run on the driver, and then it will be
        exported to all of the workers to be run. It will also be run on any
        new workers that register later. If ray.init has not been called yet,
        then cache the function and export it later.

        Args:
            function (Callable): The function to run on all of the workers. It
                is called with a single dictionary argument; on the driver that
                dictionary has a "counter" key holding how many processes on
                this node had already run the function. If it returns anything,
                its return values will not be used.

        Raises:
            Exception: If called on a worker that is not a driver.
        """
        check_main_thread()
        if self.mode not in [None, SCRIPT_MODE, SILENT_MODE, PYTHON_MODE]:
            raise Exception("run_function_on_all_workers can only be called on a driver.")
        # If ray.init has not been called yet, then cache the function and
        # export it when connect is called. Otherwise, run the function on all
        # workers.
        if self.mode is None:
            self.cached_functions_to_run.append(function)
        else:
            function_to_run_id = random_string()
            key = "FunctionsToRun:{}".format(function_to_run_id)
            # First run the function on the driver. Pass in the number of
            # workers on this node that have already started executing this
            # remote function, and increment that value. Subtract 1 so that the
            # counter starts at 0.
            counter = self.redis_client.hincrby(self.node_ip_address, key, 1) - 1
            function({"counter": counter})
            # Run the function on all workers.
            self.redis_client.hmset(key, {"driver_id": self.task_driver_id.id(),
                                          "function_id": function_to_run_id,
                                          "function": pickling.dumps(function)})
            self.redis_client.rpush("Exports", key)
            self.driver_export_counter += 1

    def push_error_to_driver(self, driver_id, error_type, message, data=None):
        """Push an error message to the driver to be printed in the background.

        Args:
            driver_id: The ID of the driver to push the error message to.
            error_type (str): The type of the error.
            message (str): The message that will be printed in the background
                on the driver.
            data: This should be a dictionary mapping strings to strings. It
                will be serialized with json and stored in Redis.
        """
        error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string()
        data = {} if data is None else data
        # Serialize data explicitly, as documented. Redis hash fields must be
        # flat values, so a nested dict cannot be stored directly.
        self.redis_client.hmset(error_key, {"type": error_type,
                                            "message": message,
                                            "data": json.dumps(data)})
        self.redis_client.rpush("ErrorKeys", error_key)
# Process-wide Worker instance. Module-level helpers such as check_connected,
# error_info and initialize_numbuf use it as their default ``worker`` argument.
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
We use a global Worker object to ensure that there is a single worker object
per worker process.
"""
# Process-wide registry of shared environment variables. Assigning an
# EnvironmentVariable to an attribute of ``env`` registers it
# (see RayEnvironmentVariables.__setattr__).
env = RayEnvironmentVariables()
"""RayEnvironmentVariables: The environment variables that are shared by tasks.
Each worker process has its own RayEnvironmentVariables object, and these
objects should be the same in all workers. This is used for storing variables
that are not serializable but must be used by remote tasks. In addition, it is
used to reinitialize these variables after they are used so that changes to
their state made by one task do not affect other tasks.
"""
class RayConnectionError(Exception):
    """Raised by check_connected when Ray is used before it has been started."""
def check_main_thread():
    """Check that we are currently on the main thread.

    The check compares the current thread's name with "MainThread".

    Raises:
        Exception: An exception is raised if this is called on a thread other
            than the main thread.
    """
    # Thread.getName() is deprecated (Python 3.10+); the ``name`` attribute is
    # equivalent and also available on Python 2.6+.
    thread_name = threading.current_thread().name
    if thread_name != "MainThread":
        raise Exception("The Ray methods are not thread safe and must be called from the main thread. This method was called from thread {}.".format(thread_name))
def check_connected(worker=global_worker):
    """Ensure that the given worker has connected to Ray.

    Args:
        worker: The worker to inspect; defaults to the process-wide worker.

    Raises:
        RayConnectionError: If ``worker.connected`` is false.
    """
    if worker.connected:
        return
    raise RayConnectionError("This command cannot be called before Ray has been started. You can start Ray with 'ray.init(num_workers=10)'.")
def print_failed_task(task_status):
    """Print a human-readable report about a failed task.

    Args:
        task_status (Dict): Must contain the keys "function_name",
            "operationid" and "error_message".
    """
    template = """
Error: Task failed
Function Name: {}
Task ID: {}
Error Message: \n{}
"""
    field_names = ("function_name", "operationid", "error_message")
    values = [task_status[field] for field in field_names]
    print(template.format(*values))
def error_applies_to_driver(error_key, worker=global_worker):
    """Return True if the error is for this driver and False otherwise.

    Error keys are laid out as
    ERROR_KEY_PREFIX + <driver ID> + b":" + <error ID>
    (the format written by Worker.push_error_to_driver). A driver ID of all
    zero bytes marks an error meant for every driver.
    """
    # TODO(rkn): Should probably check that this is only called on a driver.
    prefix_len = len(ERROR_KEY_PREFIX)
    expected_len = prefix_len + DRIVER_ID_LENGTH + 1 + ERROR_ID_LENGTH
    assert len(error_key) == expected_len, error_key
    broadcast_id = DRIVER_ID_LENGTH * b"\x00"
    key_driver_id = error_key[prefix_len:prefix_len + DRIVER_ID_LENGTH]
    return key_driver_id == worker.task_driver_id.id() or key_driver_id == broadcast_id
def error_info(worker=global_worker):
    """Return information about failed tasks.

    Returns:
        A list with the Redis hash contents of every error key in the
        "ErrorKeys" list that applies to this driver.
    """
    check_connected(worker)
    check_main_thread()
    redis_client = worker.redis_client
    all_keys = redis_client.lrange("ErrorKeys", 0, -1)
    return [redis_client.hgetall(key) for key in all_keys
            if error_applies_to_driver(key, worker=worker)]
def initialize_numbuf(worker=global_worker):
    """Initialize the serialization library.

    This defines a custom serializer for object IDs and also tells numbuf to
    serialize several exception classes that we define for error handling.

    Args:
        worker: The worker being initialized. Only drivers (SCRIPT_MODE or
            SILENT_MODE) register the error classes.
    """
    # Define a custom serializer and deserializer for handling Object IDs.
    def objectid_custom_serializer(obj):
        # Record every serialized ObjectID in the module-level
        # contained_objectids list; callers are responsible for resetting it.
        # (An unused class-identifier computation was removed from this
        # per-ObjectID hot path.)
        contained_objectids.append(obj)
        return obj.id()

    def objectid_custom_deserializer(serialized_obj):
        return photon.ObjectID(serialized_obj)

    serialization.add_class_to_whitelist(photon.ObjectID, pickle=False, custom_serializer=objectid_custom_serializer, custom_deserializer=objectid_custom_deserializer)
    if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
        # These should only be called on the driver because register_class will
        # export the class to all of the workers.
        register_class(RayTaskError)
        register_class(RayGetError)
        register_class(RayGetArgumentError)
def get_address_info_from_redis_helper(redis_address, node_ip_address):
    """Collect this node's plasma managers and local schedulers from Redis.

    Args:
        redis_address (str): The Redis address in "ip:port" form.
        node_ip_address (str): Only clients registered with this IP address
            are considered.

    Returns:
        A dict with keys "node_ip_address", "redis_address",
        "object_store_addresses" and "local_scheduler_socket_names".
    """
    redis_ip_address, redis_port = redis_address.split(":")
    # For this command to work, some other client (on the same machine as
    # Redis) must have run "CONFIG SET protected-mode no".
    redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
    # The client table prefix must be kept in sync with the file
    # "src/common/redis_module/ray_redis_module.c" where it is defined.
    REDIS_CLIENT_TABLE_PREFIX = "CL:"
    client_keys = redis_client.keys("{}*".format(REDIS_CLIENT_TABLE_PREFIX))

    # Keep only the clients on this node, sanity-checking each entry.
    plasma_managers = []
    local_schedulers = []
    for client_key in client_keys:
        entry = redis_client.hgetall(client_key)
        for required_field in (b"ray_client_id", b"node_ip_address", b"client_type"):
            assert required_field in entry
        if entry[b"node_ip_address"].decode("ascii") != node_ip_address:
            continue
        client_type = entry[b"client_type"].decode("ascii")
        if client_type == "plasma_manager":
            plasma_managers.append(entry)
        elif client_type == "photon":
            local_schedulers.append(entry)

    # Make sure that we got at least one plasma manager and local scheduler.
    assert len(plasma_managers) >= 1
    assert len(local_schedulers) >= 1

    object_store_addresses = []
    for manager in plasma_managers:
        manager_port = services.get_port(manager[b"address"].decode("ascii"))
        object_store_addresses.append(services.ObjectStoreAddress(
            name=manager[b"store_socket_name"].decode("ascii"),
            manager_name=manager[b"manager_socket_name"].decode("ascii"),
            manager_port=manager_port))

    scheduler_names = [entry[b"local_scheduler_socket_name"].decode("ascii")
                       for entry in local_schedulers]
    return {
        "node_ip_address": node_ip_address,
        "redis_address": redis_address,
        "object_store_addresses": object_store_addresses,
        "local_scheduler_socket_names": scheduler_names,
    }
def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5):
    """Query Redis for this node's address info, retrying while it is incomplete.

    The lookup is attempted up to ``num_retries + 1`` times, sleeping one second
    between attempts. The last failure is re-raised.
    """
    attempt = 0
    while True:
        try:
            return get_address_info_from_redis_helper(redis_address, node_ip_address)
        except Exception:
            if attempt == num_retries:
                raise
        # Some of the information may not be in Redis yet, so wait a bit.
        print("Some processes that the driver needs to connect to have not "
              "registered with Redis, so retrying. Have you run "
              "./scripts/start_ray.sh on this node?")
        time.sleep(1)
        attempt += 1
def _init(address_info=None, start_ray_local=False, object_id_seed=None,
          num_workers=None, num_local_schedulers=None,
          driver_mode=SCRIPT_MODE):
    """Helper method to connect to an existing Ray cluster or start a new one.

    This method handles two cases. Either a Ray cluster already exists and we
    just attach this driver to it, or we start all of the processes associated
    with a Ray cluster and attach to the newly started cluster.

    Args:
        address_info (dict): A dictionary with address information for
            processes in a partially-started Ray cluster. If
            start_ray_local=True, any processes not in this dictionary will be
            started. If provided, address_info will be modified to include
            processes that are newly started.
        start_ray_local (bool): If True then this will start any processes not
            already in address_info, including Redis, a global scheduler, local
            scheduler(s), object store(s), and worker(s). It will also kill
            these processes when Python exits. If False, this will attach to an
            existing Ray cluster.
        object_id_seed (int): Used to seed the deterministic generation of
            object IDs. The same value can be used across multiple runs of the
            same job in order to generate the object IDs in a consistent
            manner. However, the same ID should not be used for different jobs.
        num_workers (int): The number of workers to start (10 if omitted). This
            is only provided if start_ray_local is True.
        num_local_schedulers (int): The number of local schedulers to start.
            This is only provided if start_ray_local is True.
        driver_mode (int): The mode in which to start the driver. This should
            be one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.

    Returns:
        Address information about the started processes.

    Raises:
        Exception: An exception is raised if an inappropriate combination of
            arguments is passed in.
    """
    check_main_thread()
    allowed_modes = [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]
    if driver_mode not in allowed_modes:
        raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].")

    # Get addresses of existing services.
    if address_info is not None:
        assert isinstance(address_info, dict)
    else:
        address_info = {}
    node_ip_address = address_info.get("node_ip_address")
    redis_address = address_info.get("redis_address")

    if driver_mode == PYTHON_MODE:
        # PYTHON_MODE runs remote functions serially in this process, so no
        # other processes are started and there is nothing to connect to.
        driver_address_info = {}
    else:
        if start_ray_local:
            # Launch a scheduler, an object store and some workers, skipping
            # any processes already registered in address_info. Local mode
            # defaults to the loopback address.
            if node_ip_address is None:
                node_ip_address = "127.0.0.1"
            # Default to 10 workers.
            if num_workers is None:
                num_workers = 10
            # Default to the number of local schedulers already listed in
            # address_info, or to 1 if none are listed.
            if num_local_schedulers is None:
                known_schedulers = address_info.get("local_scheduler_socket_names", [])
                num_local_schedulers = len(known_schedulers) if len(known_schedulers) > 0 else 1
            # These processes will be killed by the call to cleanup(), which
            # happens when the Python script exits.
            address_info = services.start_ray_head(address_info=address_info,
                                                   node_ip_address=node_ip_address,
                                                   num_workers=num_workers,
                                                   num_local_schedulers=num_local_schedulers)
        else:
            if redis_address is None:
                raise Exception("If start_ray_local=False, then redis_address must be provided.")
            if num_workers is not None:
                raise Exception("If start_ray_local=False, then num_workers must not be provided.")
            if num_local_schedulers is not None:
                raise Exception("If start_ray_local=False, then num_local_schedulers must not be provided.")
            if node_ip_address is None:
                node_ip_address = services.get_node_ip_address(redis_address)
            # Ask Redis where the processes to connect to live.
            address_info = get_address_info_from_redis(redis_address, node_ip_address)
        # Attach to the first object store and local scheduler when there are
        # several.
        driver_address_info = {
            "node_ip_address": node_ip_address,
            "redis_address": address_info["redis_address"],
            "store_socket_name": address_info["object_store_addresses"][0].name,
            "manager_socket_name": address_info["object_store_addresses"][0].manager_name,
            "local_scheduler_socket_name": address_info["local_scheduler_socket_names"][0],
        }

    # The matching disconnect happens in cleanup() when the Python script
    # exits.
    connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker)
    return address_info
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
         num_workers=None, driver_mode=SCRIPT_MODE):
  """Either connect to an existing Ray cluster or start one and connect to it.

  This method handles two cases. Either a Ray cluster already exists and we
  just attach this driver to it, or we start all of the processes associated
  with a Ray cluster and attach to the newly started cluster.

  Args:
    redis_address (str): The address of the Redis server to connect to. If this
      address is not provided, then this command will start Redis, a global
      scheduler, a local scheduler, a plasma store, a plasma manager, and some
      workers. It will also kill these processes when Python exits.
    node_ip_address (str): The IP address of the node that we are on.
    object_id_seed (int): Used to seed the deterministic generation of object
      IDs. The same value can be used across multiple runs of the same job in
      order to generate the object IDs in a consistent manner. However, the same
      ID should not be used for different jobs.
    num_workers (int): The number of workers to start. This is only provided if
      redis_address is not provided.
    driver_mode (int): The mode in which to start the driver. This should be
      one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.

  Returns:
    Address information about the started processes.

  Raises:
    Exception: An exception is raised if an inappropriate combination of
      arguments is passed in.
  """
  info = {
    "node_ip_address": node_ip_address,
    "redis_address": redis_address,
  }
  # Forward object_id_seed to _init (which hands it to connect). It used to be
  # accepted here and then silently dropped, so seeding had no effect.
  return _init(address_info=info, start_ray_local=(redis_address is None),
               object_id_seed=object_id_seed, num_workers=num_workers,
               driver_mode=driver_mode)
def cleanup(worker=global_worker):
  """Disconnect the driver and stop any processes that init started.

  Registered with atexit right below, so it runs automatically when a Python
  process that uses Ray exits; running it twice in a row is fine. The tests
  call services.cleanup() themselves because they start and stop many
  clusters inside one process, while the import and the exit happen only once.
  """
  disconnect(worker)
  worker.set_mode(None)
  # Reset the export/import bookkeeping for any later connection.
  worker.driver_export_counter, worker.worker_import_counter = 0, 0
  # connect() only creates a plasma client outside of PYTHON_MODE.
  if hasattr(worker, "plasma_client"):
    worker.plasma_client.shutdown()
  services.cleanup()


atexit.register(cleanup)
def print_error_messages(worker):
  """Print error messages in the background on the driver.

  This runs in a separate thread on the driver. It first prints the errors
  already recorded in the Redis list "ErrorKeys" that apply to this driver.
  It then listens for keyspace notifications on that list and prints each
  newly appended applicable error. It returns once the Redis connection is
  closed.

  Args:
    worker: The driver's worker object; its redis_client, redis_address and
      lock fields must already be set.
  """
  # TODO(rkn): All error messages should have a "component" field indicating
  # which process the error came from (e.g., a worker or a plasma store).
  # Currently all error messages come from workers.
  helpful_message = """
You can inspect errors by running
ray.error_info()
If this driver is hanging, start a new one with
ray.init(redis_address="{}")
""".format(worker.redis_address)

  def print_applicable_errors(error_keys):
    # Print the errors among error_keys that apply to this driver and return
    # how many keys were consumed. Every key is counted, not only the printed
    # ones, because the caller uses the total as a start index into the
    # shared "ErrorKeys" list. Counting only this driver's errors made that
    # index lag behind and caused errors to be printed repeatedly.
    for error_key in error_keys:
      if error_applies_to_driver(error_key, worker=worker):
        error_message = worker.redis_client.hget(error_key, "message").decode("ascii")
        print(error_message)
        print(helpful_message)
    return len(error_keys)

  worker.error_message_pubsub_client = worker.redis_client.pubsub()
  # Errors that are published after the call to
  # error_message_pubsub_client.psubscribe and before the call to
  # error_message_pubsub_client.listen will still be processed in the loop.
  worker.error_message_pubsub_client.psubscribe("__keyspace@0__:ErrorKeys")
  # Index of the first entry of "ErrorKeys" that has not been examined yet.
  num_errors_received = 0
  # Handle the errors that occurred before the call to psubscribe.
  with worker.lock:
    error_keys = worker.redis_client.lrange("ErrorKeys", 0, -1)
    num_errors_received += print_applicable_errors(error_keys)
  try:
    for _ in worker.error_message_pubsub_client.listen():
      with worker.lock:
        new_error_keys = worker.redis_client.lrange("ErrorKeys", num_errors_received, -1)
        num_errors_received += print_applicable_errors(new_error_keys)
  except redis.ConnectionError:
    # When Redis terminates the listen call will throw a ConnectionError, which
    # we catch here.
    pass
def fetch_and_register_remote_function(key, worker=global_worker):
  """Import a remote function that a driver exported.

  The Redis hash at `key` provides the exporting driver's ID, the function
  ID, name, pickled body, number of return values, module name and export
  counter. This metadata is recorded on the worker, and a remote wrapper is
  registered under the function ID. If the body cannot be unpickled, a
  placeholder that raises when called stays registered and the traceback is
  pushed to the driver. On success, the worker's ID is appended to the Redis
  list "FunctionTable:<function id>".

  Args:
    key: The Redis key of the exported function's hash.
    worker: The worker doing the import.
  """
  field_names = ["driver_id", "function_id", "name", "function",
                 "num_return_vals", "module", "function_export_counter"]
  (driver_id, function_id_str, raw_name, serialized_function,
   raw_num_return_vals, raw_module,
   raw_export_counter) = worker.redis_client.hmget(key, field_names)

  function_id = photon.ObjectID(function_id_str)
  function_name = raw_name.decode("ascii")
  num_return_vals = int(raw_num_return_vals)
  module = raw_module.decode("ascii")
  function_export_counter = int(raw_export_counter)

  function_key = function_id.id()
  worker.function_names[function_key] = function_name
  worker.num_return_vals[function_key] = num_return_vals
  worker.function_export_counters[function_key] = function_export_counter

  # Register a placeholder first in case the function can't be unpickled; it
  # is overwritten below when unpickling succeeds.
  def f():
    raise Exception("This function was not imported properly.")
  worker.functions[function_key] = remote(num_return_vals=num_return_vals, function_id=function_id)(lambda *xs: f())

  try:
    function = pickling.loads(serialized_function)
  except:
    # Unpickling failed: format the traceback and report it to the driver.
    # The placeholder registered above remains in place.
    traceback_str = format_error_message(traceback.format_exc())
    worker.push_error_to_driver(driver_id, "register_remote_function",
                                traceback_str,
                                data={"function_id": function_key,
                                      "function_name": function_name})
  else:
    # TODO(rkn): Why is the below line necessary?
    function.__module__ = module
    worker.functions[function_key] = remote(num_return_vals=num_return_vals, function_id=function_id)(function)
    # Add the function to the function table.
    worker.redis_client.rpush("FunctionTable:{}".format(function_key), worker.worker_id)
def fetch_and_register_environment_variable(key, worker=global_worker):
  """Import an environment variable that a driver exported.

  The Redis hash at `key` provides the exporting driver's ID, the variable
  name, and its pickled initializer and reinitializer. On success, an
  EnvironmentVariable is installed under that name on the module-level env
  object. Any failure is reported to the driver instead of being raised.

  Args:
    key: The Redis key of the exported variable's hash.
    worker: The worker doing the import.
  """
  driver_id, raw_name, serialized_initializer, serialized_reinitializer = (
      worker.redis_client.hmget(key, ["driver_id", "name", "initializer", "reinitializer"]))
  variable_name = raw_name.decode("ascii")
  try:
    initializer = pickling.loads(serialized_initializer)
    reinitializer = pickling.loads(serialized_reinitializer)
    env.__setattr__(variable_name, EnvironmentVariable(initializer, reinitializer))
  except:
    # The import failed: send the formatted traceback to the driver instead
    # of letting the exception escape.
    worker.push_error_to_driver(driver_id, "register_environment_variable",
                                format_error_message(traceback.format_exc()),
                                data={"name": variable_name})
def fetch_and_execute_function_to_run(key, worker=global_worker):
  """Run an arbitrary function that a driver exported on this worker.

  The function is unpickled from the Redis hash at `key` and called with
  {"counter": n}. Here n is the number of workers on this node that had
  already started this function before this one, so the first sees 0.
  Failures, whether in unpickling or in running, are reported to the
  exporting driver rather than raised.

  Args:
    key: The Redis key of the exported function's hash.
    worker: The worker running the function.
  """
  driver_id, serialized_function = worker.redis_client.hmget(key, ["driver_id", "function"])
  # hincrby returns the incremented value; subtract 1 so counting starts at 0.
  counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1
  # Stays None if unpickling fails, in which case the reported name is "".
  function = None
  try:
    function = pickling.loads(serialized_function)
    function({"counter": counter})
  except:
    traceback_str = traceback.format_exc()
    function_name = getattr(function, "__name__", "")
    worker.push_error_to_driver(driver_id, "function_to_run", traceback_str,
                                data={"name": function_name})
def import_thread(worker):
  """Body of a worker's background thread that imports drivers' exports.

  Subscribes to keyspace notifications for the Redis list "Exports" and
  imports every export already in the list. Then, on each notification, it
  imports the exports appended since the previous pass. Each export key's
  prefix selects its importer: remote function, environment variable, or
  function to run. Progress is kept in worker.worker_import_counter and
  mirrored in the "export_counter" field of "WorkerInfo:<worker id>".

  Args:
    worker: The worker whose imports are processed.
  """
  # (key prefix, span name, importer) for every kind of export.
  importers = [
    (b"RemoteFunction", "ray:import_remote_function",
     fetch_and_register_remote_function),
    (b"EnvironmentVariables", "ray:import_environment_variable",
     fetch_and_register_environment_variable),
    (b"FunctionsToRun", "ray:import_function_to_run",
     fetch_and_execute_function_to_run),
  ]

  def import_export(export_key, log_as_span):
    # Dispatch export_key to its importer, optionally inside a log span.
    for prefix, span_name, importer in importers:
      if export_key.startswith(prefix):
        if log_as_span:
          with log_span(span_name, worker=worker):
            importer(export_key, worker=worker)
        else:
          importer(export_key, worker=worker)
        return
    raise Exception("This code should be unreachable.")

  worker.import_pubsub_client = worker.redis_client.pubsub()
  # Exports that are published after the call to import_pubsub_client.psubscribe
  # and before the call to import_pubsub_client.listen will still be processed
  # in the loop.
  worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports")
  worker_info_key = "WorkerInfo:{}".format(worker.worker_id)
  worker.redis_client.hset(worker_info_key, "export_counter", 0)
  worker.worker_import_counter = 0

  def mark_imported():
    # Record one more completed import, both in Redis and locally.
    worker.redis_client.hincrby(worker_info_key, "export_counter", 1)
    worker.worker_import_counter += 1

  # Import the exports that occurred before the call to psubscribe.
  with worker.lock:
    for export_key in worker.redis_client.lrange("Exports", 0, -1):
      import_export(export_key, log_as_span=False)
      mark_imported()

  for message in worker.import_pubsub_client.listen():
    with worker.lock:
      if message["type"] == "psubscribe":
        continue
      assert message["data"] == b"rpush"
      total_exports = worker.redis_client.llen("Exports")
      assert total_exports >= worker.worker_import_counter
      for index in range(worker.worker_import_counter, total_exports):
        export_key = worker.redis_client.lindex("Exports", index)
        import_export(export_key, log_as_span=True)
        mark_imported()
def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker):
  """Connect this worker to the local scheduler, to Plasma, and to Redis.

  Besides creating the clients, this:
  - registers the process in Redis as a driver or a worker;
  - sets up task-ID bookkeeping for drivers;
  - starts the background import thread (workers) or the error-printing
    thread (SCRIPT_MODE drivers);
  - for drivers, exports everything that was cached before a connection
    existed.
  In PYTHON_MODE only the basic fields are set and no clients are created.

  Args:
    info (dict): A dictionary with address of the Redis server and the sockets
      of the plasma store, plasma manager, and local scheduler. The keys read
      are "node_ip_address", "redis_address" (formatted "host:port"),
      "store_socket_name", "manager_socket_name" and
      "local_scheduler_socket_name".
    object_id_seed (int): If not None, a driver seeds the generation of its
      initial task ID with this value so object IDs can be reproduced across
      runs; otherwise the ID is random. The numpy global RNG state is saved
      and restored around this.
    mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE,
      and SILENT_MODE.
    worker: The worker object to connect (normally the global worker).

  Raises:
    AssertionError: If the worker looks already connected (e.g. ray.init was
      called twice).
  """
  check_main_thread()
  # Do some basic checking to make sure we didn't call ray.init twice. The
  # caches below are set to None at the end of connect and reset to lists by
  # disconnect, so None here means an earlier connect was never undone.
  error_message = "Perhaps you called ray.init twice by accident?"
  assert not worker.connected, error_message
  assert worker.cached_functions_to_run is not None, error_message
  assert worker.cached_remote_functions is not None, error_message
  assert env._cached_environment_variables is not None, error_message
  # Initialize some fields.
  worker.worker_id = random_string()
  worker.connected = True
  worker.set_mode(mode)
  # The worker.events field is used to aggregate logging information and display
  # it in the web UI. Note that Python lists are protected by the GIL, which is
  # important because we will append to this field from multiple threads.
  worker.events = []
  # If running Ray in PYTHON_MODE, there is no need to create the Redis,
  # plasma or local scheduler clients or to start any background threads.
  if mode == PYTHON_MODE:
    return
  # Set the node IP address.
  worker.node_ip_address = info["node_ip_address"]
  worker.redis_address = info["redis_address"]
  # Create a Redis client.
  redis_ip_address, redis_port = info["redis_address"].split(":")
  worker.redis_client = redis.StrictRedis(host=redis_ip_address, port=int(redis_port))
  # Lock shared with the background import / error-printing threads.
  worker.lock = threading.Lock()
  # Create an object store client.
  worker.plasma_client = plasma.PlasmaClient(info["store_socket_name"], info["manager_socket_name"])
  # Create the local scheduler client.
  worker.photon_client = photon.PhotonClient(info["local_scheduler_socket_name"])
  # Register the worker with Redis.
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    worker.redis_client.hmset(b"Drivers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address})
  elif mode == WORKER_MODE:
    worker.redis_client.hmset(b"Workers:" + worker.worker_id, {"node_ip_address": worker.node_ip_address})
  else:
    raise Exception("This code should be unreachable.")
  # If this is a driver, set the current task ID, the task driver ID, and set
  # the task index to 0.
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    # If the user provided an object_id_seed, then set the current task ID
    # deterministically based on that seed (without altering the state of the
    # user's random number generator). Otherwise, set the current task ID
    # randomly to avoid object ID collisions.
    numpy_state = np.random.get_state()
    if object_id_seed is not None:
      np.random.seed(object_id_seed)
    else:
      # Try to use true randomness.
      np.random.seed(None)
    worker.current_task_id = photon.ObjectID(np.random.bytes(20))
    # When tasks are executed on remote workers in the context of multiple
    # drivers, the task driver ID is used to keep track of which driver is
    # responsible for the task so that error messages will be propagated to the
    # correct driver.
    worker.task_driver_id = photon.ObjectID(worker.worker_id)
    # Reset the state of the numpy random number generator.
    np.random.set_state(numpy_state)
    # Set other fields needed for computing task IDs.
    worker.task_index = 0
    worker.put_index = 0
  # If this is a worker, then start a thread to import exports from the driver.
  if mode == WORKER_MODE:
    t = threading.Thread(target=import_thread, args=(worker,))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()
  # If this is a driver running in SCRIPT_MODE, start a thread to print error
  # messages asynchronously in the background. Ideally the scheduler would push
  # messages to the driver's worker service, but we ran into bugs when trying to
  # properly shutdown the driver's worker service, so we are temporarily using
  # this implementation which constantly queries the scheduler for new error
  # messages.
  if mode == SCRIPT_MODE:
    t = threading.Thread(target=print_error_messages, args=(worker,))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()
  # Initialize the serialization library. This registers some classes, and so
  # it must be run before we export all of the cached remote functions.
  initialize_numbuf()
  if mode in [SCRIPT_MODE, SILENT_MODE]:
    # Add the directory containing the script that is running to the Python
    # paths of the workers. Also add the current directory. Note that this
    # assumes that the directory structures on the machines in the clusters are
    # the same.
    script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
    current_directory = os.path.abspath(os.path.curdir)
    worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, script_directory))
    worker.run_function_on_all_workers(lambda worker_info: sys.path.insert(1, current_directory))
    # TODO(rkn): Here we first export functions to run, then environment
    # variables, then remote functions. The order matters. For example, one of
    # the functions to run may set the Python path, which is needed to import a
    # module used to define an environment variable, which in turn is used
    # inside a remote function. We may want to change the order to simply be the
    # order in which the exports were defined on the driver. In addition, we
    # will need to retain the ability to decide what the first few exports are
    # (mostly to set the Python path). Additionally, note that the first exports
    # to be defined on the driver will be the ones defined in separate modules
    # that are imported by the driver.
    # Export cached functions_to_run.
    for function in worker.cached_functions_to_run:
      worker.run_function_on_all_workers(function)
    # Export cached environment variables to the workers.
    for name, environment_variable in env._cached_environment_variables:
      env.__setattr__(name, environment_variable)
    # Export cached remote functions to the workers.
    for function_id, func_name, func, num_return_vals in worker.cached_remote_functions:
      export_remote_function(function_id, func_name, func, num_return_vals, worker)
  # Mark the caches as consumed; the assertions at the top of connect check
  # for this state.
  worker.cached_functions_to_run = None
  worker.cached_remote_functions = None
  env._cached_environment_variables = None
def disconnect(worker=global_worker):
  """Mark the worker as disconnected and restore its export caches.

  No client connection is closed here. The function clears worker.connected
  and replaces the caches of functions to run, remote functions and
  environment variables with fresh empty lists. Definitions made afterwards
  are then cached again and exported by the next connect(), which mostly
  matters for the tests.
  """
  worker.connected = False
  # Empty lists (rather than None) are what connect() expects to find.
  worker.cached_functions_to_run = []
  worker.cached_remote_functions = []
  env._cached_environment_variables = []
def register_class(cls, pickle=False, worker=global_worker):
  """Enable workers to serialize or deserialize objects of a particular class.

  On a driver, this ships a small function to every worker that whitelists
  cls with the serialization module, so numbuf can (de)serialize instances
  of it. On a worker it does nothing, so modules that call register_class
  can be imported on workers without trouble.

  Args:
    cls (type): The class that numbuf should serialize.
    pickle (bool): If False then objects of this class will be serialized by
      turning their __dict__ fields into a dictionary. If True, then objects
      of this class will be serialized using pickle.

  Raises:
    Exception: An exception is raised if pickle=False and the class cannot be
      efficiently serialized by Ray.
  """
  if worker.mode != WORKER_MODE:
    # Fail early on the driver if cls cannot be serialized efficiently.
    if not pickle:
      serialization.check_serializable(cls)

    def register_class_for_serialization(worker_info):
      serialization.add_class_to_whitelist(cls, pickle=pickle)

    worker.run_function_on_all_workers(register_class_for_serialization)
class RayLogSpan(object):
  """Context manager that logs a span of events around a with-block.

  Entering logs a LOG_SPAN_START event. Exiting logs a LOG_SPAN_END event,
  with the exception's type, value and traceback when the block raised.
  Exceptions are never suppressed.

  Attributes:
    event_type (str): The type of the event being logged.
    contents: Additional information logged with the start event.
    worker: The worker whose event buffer receives the events.
  """

  def __init__(self, event_type, contents=None, worker=global_worker):
    """Remember the event type, optional contents and target worker."""
    self.event_type = event_type
    self.contents = contents
    self.worker = worker

  def __enter__(self):
    """Log the start of the span."""
    log(event_type=self.event_type, contents=self.contents,
        kind=LOG_SPAN_START, worker=self.worker)

  def __exit__(self, exc_type, exc_value, exc_tb):
    """Log the end of the span, including any exception raised inside it."""
    end_contents = None
    if exc_type is not None:
      end_contents = {"type": str(exc_type),
                      "value": exc_value,
                      "traceback": traceback.format_exc()}
    log(event_type=self.event_type, contents=end_contents,
        kind=LOG_SPAN_END, worker=self.worker)
def log_span(event_type, contents=None, worker=global_worker):
  """Return a RayLogSpan so a span of events can be logged with `with`."""
  span = RayLogSpan(event_type, contents=contents, worker=worker)
  return span
def log_event(event_type, contents=None, worker=global_worker):
  """Log a single point-in-time event (kind LOG_POINT)."""
  log(event_type=event_type, kind=LOG_POINT, contents=contents, worker=worker)
def log(event_type, kind, contents=None, worker=global_worker):
  """Buffer an event locally; flush_log() later sends it to the state store.

  Each event is stored on worker.events as a tuple
  (timestamp, event_type, kind, contents), with every key and value of
  contents converted to str.

  Args:
    event_type (str): The type of the event.
    kind (int): LOG_POINT for an event at a single point in time,
      LOG_SPAN_START when a span of time begins, or LOG_SPAN_END when it ends.
    contents (dict): More general data to store with the event.
  """
  # TODO(rkn): This code currently takes around half a microsecond. Since we
  # call it tens of times per task, this adds up. We will need to redo the
  # logging code, perhaps in C.
  if contents is None:
    contents = {}
  assert isinstance(contents, dict)
  stringified = dict((str(key), str(value)) for key, value in contents.items())
  worker.events.append((time.time(), event_type, kind, stringified))
def flush_log(worker=global_worker):
  """Send the buffered worker events to the global state store and clear them.

  The events are JSON-encoded and passed to the local scheduler client's
  log_event under the key b"event_log:<worker id>:<current task id>".
  """
  event_log_key = b":".join([b"event_log", worker.worker_id,
                             worker.current_task_id.id()])
  worker.photon_client.log_event(event_log_key, json.dumps(worker.events))
  worker.events = []
def get(object_ids, worker=global_worker):
  """Get a remote object or a list of remote objects from the object store.

  Blocks until every requested object is available in the local object
  store, shipping it from another store if necessary (once it has been
  created). A single object ID yields a single value; a list of IDs yields a
  list of values in the same order.

  Args:
    object_ids: Object ID of the object to get or a list of object IDs to get.

  Returns:
    A Python object or a list of Python objects.

  Raises:
    RayGetError: If a requested object is a RayTaskError, i.e. the task that
      should have produced it failed.
  """
  check_connected(worker)
  with log_span("ray:get", worker=worker):
    check_main_thread()
    if worker.mode == PYTHON_MODE:
      # In PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objectid)
      return object_ids
    requested_list = isinstance(object_ids, list)
    requested = object_ids if requested_list else [object_ids]
    values = worker.get_object(requested)
    for object_id, value in zip(requested, values):
      if isinstance(value, RayTaskError):
        # The producing task failed, so propagate its error here.
        raise RayGetError(object_id, value)
    return values if requested_list else values[0]
def put(value, worker=global_worker):
  """Store an object in the object store.

  The object ID is derived from the current task ID and a per-task put
  counter, which is then incremented.

  Args:
    value (serializable object): The Python object to be stored.

  Returns:
    The object ID assigned to this value (or the value itself in PYTHON_MODE).
  """
  check_connected(worker)
  with log_span("ray:put", worker=worker):
    check_main_thread()
    if worker.mode != PYTHON_MODE:
      object_id = photon.compute_put_id(worker.current_task_id, worker.put_index)
      worker.put_object(object_id, value)
      worker.put_index += 1
      return object_id
    # In PYTHON_MODE, ray.put is the identity operation.
    return value
def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
  """Return a list of IDs that are ready and a list of IDs that are not ready.

  If timeout is set, the function returns either when the requested number of
  IDs are ready or when the timeout is reached, whichever occurs first. If it is
  not set, the function simply waits until that number of objects is ready and
  returns that exact number of objectids.

  This method returns two lists. The first list consists of object IDs that
  correspond to objects that are stored in the object store. The second list
  corresponds to the rest of the object IDs (which may or may not be ready).

  In PYTHON_MODE the inputs are already values, because ray.put and remote
  calls return values there. Everything counts as ready, so the first
  num_returns inputs are returned as ready and the rest as remaining.

  Args:
    object_ids (List[ObjectID]): List of object IDs for objects that may
      or may not be ready.
    num_returns (int): The number of object IDs that should be returned.
    timeout (int): The maximum amount of time in milliseconds to wait before
      returning.

  Returns:
    A list of object IDs that are ready and a list of the remaining object IDs.
  """
  check_connected(worker)
  with log_span("ray:wait", worker=worker):
    check_main_thread()
    if worker.mode == PYTHON_MODE:
      # Mirror get/put: there is no plasma client in PYTHON_MODE and the
      # inputs are not ObjectIDs, so calling .id() on them would fail.
      return list(object_ids[:num_returns]), list(object_ids[num_returns:])
    object_id_strs = [object_id.id() for object_id in object_ids]
    # Without a timeout, wait 2 ** 30 ms (about 12 days), i.e. effectively
    # forever.
    timeout = timeout if timeout is not None else 2 ** 30
    ready_ids, remaining_ids = worker.plasma_client.wait(object_id_strs, timeout, num_returns)
    ready_ids = [photon.ObjectID(object_id) for object_id in ready_ids]
    remaining_ids = [photon.ObjectID(object_id) for object_id in remaining_ids]
    return ready_ids, remaining_ids
def wait_for_valid_import_counter(function_id, driver_id, timeout=5, worker=global_worker):
  """Wait until this worker has imported enough to execute the function.

  This method loops until the import thread has registered the function and
  imported at least as many exports as the function's export counter. If
  that takes longer than `timeout` seconds, which may indicate a problem
  somewhere, a single warning is pushed to the driver and waiting continues.

  Args:
    function_id: The ID of the function that we want to execute (an object
      exposing .id(), e.g. a photon.ObjectID).
    driver_id (str): The ID of the driver to push the error message to if this
      times out.
    timeout (float): Seconds to wait before pushing the warning.
    worker: The worker that will execute the function.
  """
  function_key = function_id.id()
  start_time = time.time()
  # The warning is pushed at most once per call. (The previous version kept
  # an unused warning counter and rebuilt the message on every iteration
  # after the timeout.)
  warning_sent = False
  while True:
    with worker.lock:
      if function_key in worker.functions and (worker.function_export_counters[function_key] <= worker.worker_import_counter):
        break
      if not warning_sent and time.time() - start_time > timeout:
        if function_key not in worker.functions:
          warning_message = "This worker was asked to execute a function that it does not have registered. You may have to restart Ray."
        else:
          warning_message = "This worker's import counter is too small."
        worker.push_error_to_driver(driver_id, "import_counter",
                                    warning_message)
        warning_sent = True
    # Sleep without holding the lock so the import thread can make progress.
    time.sleep(0.001)
def format_error_message(exception_message, task_exception=False):
  """Tidy up a traceback produced by a remote function or an import.

  Args:
    exception_message (str): A message generated by traceback.format_exc().
    task_exception (bool): True if the traceback was raised inside a task.
      Such tracebacks always contain the same four lines about the worker's
      main loop right after the first line, and those lines are removed.

  Returns:
    The formatted exception message. When task_exception is False the
    message comes back unchanged.
  """
  message_lines = exception_message.split("\n")
  if task_exception:
    # Keep the first line and drop lines 1 through 4.
    del message_lines[1:5]
  return "\n".join(message_lines)
def main_loop(worker=global_worker):
  """The main loop a worker runs to receive and execute tasks.

  This method is an infinite loop. Each iteration:
  - gets a task from the local scheduler;
  - waits until this worker has imported enough exports to run the task's
    function;
  - executes the task while holding worker.lock;
  - flushes this worker's logged events to the global state store.
  Failures inside a task are stored as RayTaskError objects in the task's
  return slots and pushed to the driver that owns the task.

  Args:
    worker: The worker running the loop.
  """

  # Wrapping these lines in a function should cause the local variables to go
  # out of scope more quickly, which is useful for inspecting reference counts.
  def process_task(task):
    """Execute a task assigned to this worker.

    This method deserializes a task from the scheduler, and attempts to execute
    the task. If the task succeeds, the outputs are stored in the local object
    store. If the task throws an exception, RayTaskError objects are stored in
    the object store to represent the failed task (these will be retrieved by
    calls to get or by subsequent tasks that use the outputs of this task).
    After the task executes, the worker resets any environment variables that
    were accessed by the task.
    """
    try:
      # The ID of the driver that this task belongs to. This is needed so that
      # if the task throws an exception, we propagate the error message to the
      # correct driver.
      worker.task_driver_id = task.driver_id()
      worker.current_task_id = task.task_id()
      worker.task_index = 0
      worker.put_index = 0
      function_id = task.function_id()
      args = task.arguments()
      return_object_ids = task.returns()
      function_name = worker.function_names[function_id.id()]
      # Get task arguments from the object store.
      with log_span("ray:task:get_arguments", worker=worker):
        arguments = get_arguments_for_execution(worker.functions[function_id.id()], args, worker)
      # Execute the task.
      with log_span("ray:task:execute", worker=worker):
        outputs = worker.functions[function_id.id()].executor(arguments)
      # Store the outputs in the local object store.
      with log_span("ray:task:store_outputs", worker=worker):
        if len(return_object_ids) == 1:
          outputs = (outputs,)
        store_outputs_in_objstore(return_object_ids, outputs, worker)
    except Exception as e:
      # We determine whether the exception was caused by the call to
      # get_arguments_for_execution or by the execution of the remote function
      # or by the call to store_outputs_in_objstore. Depending on which case
      # occurred, we format the error message differently. The cases are told
      # apart by checking whether the variables "arguments" and "outputs" have
      # been assigned yet.
      if "arguments" in locals() and "outputs" not in locals():
        # The error occurred during the task execution.
        traceback_str = format_error_message(traceback.format_exc(), task_exception=True)
      elif "arguments" in locals() and "outputs" in locals():
        # The error occurred after the task executed.
        traceback_str = format_error_message(traceback.format_exc())
      else:
        # The error occurred before the task execution.
        traceback_str = None
      failure_object = RayTaskError(function_name, e, traceback_str)
      failure_objects = [failure_object for _ in range(len(return_object_ids))]
      store_outputs_in_objstore(return_object_ids, failure_objects, worker)
      # Log the error message.
      worker.push_error_to_driver(worker.task_driver_id.id(), "task",
                                  str(failure_object),
                                  data={"function_id": function_id.id(),
                                        "function_name": function_name})
    try:
      # Reinitialize the values of environment variables that were used in the
      # task above so that changes made to their state do not affect other tasks.
      with log_span("ray:task:reinitialize_environment_variables", worker=worker):
        env._reinitialize()
    except Exception:
      # The attempt to reinitialize the environment variables threw an
      # exception. We record the traceback and notify the scheduler.
      traceback_str = format_error_message(traceback.format_exc())
      worker.push_error_to_driver(worker.task_driver_id.id(),
                                  "reinitialize_environment_variable",
                                  traceback_str,
                                  data={"function_id": function_id.id(),
                                        "function_name": function_name})

  check_main_thread()
  while True:
    with log_span("ray:get_task", worker=worker):
      task = worker.photon_client.get_task()
    function_id = task.function_id()
    # Check that the number of imports we have is at least as great as the
    # export counter for the task. If not, wait until we have imported enough.
    # We will push warnings to the user if we spend too long in this loop.
    with log_span("ray:wait_for_import_counter", worker=worker):
      wait_for_valid_import_counter(function_id, task.driver_id().id(), worker=worker)
    # Execute the task.
    # TODO(rkn): Consider acquiring this lock with a timeout and pushing a
    # warning to the user if we are waiting too long to acquire the lock because
    # that may indicate that the system is hanging, and it'd be good to know
    # where the system is hanging.
    log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=worker)
    with worker.lock:
      log(event_type="ray:acquire_lock", kind=LOG_SPAN_END, worker=worker)
      contents = {"function_name": worker.function_names[function_id.id()],
                  "task_id": task.task_id().hex()}
      with log_span("ray:task", contents=contents, worker=worker):
        process_task(task)
    # Push all of the log events to the global state store. Flush this loop's
    # worker, not the module-level default, so a non-default worker's events
    # are actually sent.
    flush_log(worker=worker)
def _submit_task(function_id, func_name, args, worker=global_worker):
  """Submit a task through `worker` (the global worker by default).

  Remote-function wrappers call this module-level helper instead of
  worker.submit_task. That way, when those wrappers are pickled, the worker
  object, which cannot be serialized, is not captured.
  """
  submit = worker.submit_task
  return submit(function_id, func_name, args)
def _mode(worker=global_worker):
  """Return worker.mode (the global worker's mode by default).

  Remote-function wrappers call _mode() instead of reading worker.mode
  directly. That way, pickling those wrappers does not try to serialize the
  worker object, which cannot be serialized.
  """
  current_mode = worker.mode
  return current_mode
def _env():
  """Return the module-level env object.

  Functions that need env go through this wrapper rather than referencing it
  directly, so that those functions can be pickled.
  """
  return env
def _export_environment_variable(name, environment_variable, worker=global_worker):
  """Export an environment variable to the workers; drivers only.

  The pickled initializer and reinitializer are written to the Redis hash
  "EnvironmentVariables:<name>". That key is appended to the "Exports" list
  consumed by the workers' import threads, and the driver's export counter
  is bumped.

  Args:
    name (str): The name of the variable to export.
    environment_variable (EnvironmentVariable): The environment variable object
      containing code for initializing and reinitializing the variable.

  Raises:
    Exception: If called on anything other than a driver.
  """
  check_main_thread()
  if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
    raise Exception("_export_environment_variable can only be called on a driver.")
  key = "EnvironmentVariables:{}".format(name)
  export_record = {"driver_id": worker.task_driver_id.id(),
                   "name": name,
                   "initializer": pickling.dumps(environment_variable.initializer),
                   "reinitializer": pickling.dumps(environment_variable.reinitializer)}
  worker.redis_client.hmset(key, export_record)
  worker.redis_client.rpush("Exports", key)
  worker.driver_export_counter += 1
def export_remote_function(function_id, func_name, func, num_return_vals, worker=global_worker):
  """Export a remote function from the driver to the workers through Redis.

  This may only be called by a driver (SCRIPT_MODE or SILENT_MODE). It records
  ``num_return_vals`` in ``worker.num_return_vals`` and stores the pickled
  function with its metadata in the Redis hash "RemoteFunction:<id>". It then
  appends that key to the "Exports" list and increments the driver's export
  counter.

  Args:
    function_id (FunctionID): The ID assigned to the remote function.
    func_name (str): The fully qualified name of the function.
    func (Callable): The function to export.
    num_return_vals (int): The number of object IDs a call should return.
    worker: The driver worker doing the export.

  Raises:
    Exception: If the current worker is not a driver.
  """
  check_main_thread()
  if _mode(worker) not in [SCRIPT_MODE, SILENT_MODE]:
    raise Exception("export_remote_function can only be called on a driver.")
  raw_id = function_id.id()
  redis_key = "RemoteFunction:{}".format(raw_id)
  worker.num_return_vals[raw_id] = num_return_vals
  serialized_func = pickling.dumps(func)
  payload = {
      "driver_id": worker.task_driver_id.id(),
      "function_id": raw_id,
      "name": func_name,
      "module": func.__module__,
      "function": serialized_func,
      "num_return_vals": num_return_vals,
      # Snapshot of the counter taken before it is incremented below.
      "function_export_counter": worker.driver_export_counter,
  }
  worker.redis_client.hmset(redis_key, payload)
  worker.redis_client.rpush("Exports", redis_key)
  worker.driver_export_counter += 1
def remote(*args, **kwargs):
  """This decorator is used to create remote functions.

  It can be applied bare (``@ray.remote``) or with the single keyword argument
  ``num_return_vals`` (``@ray.remote(num_return_vals=2)``). In WORKER_MODE it
  may also be called with both ``num_return_vals`` and ``function_id``. In that
  case, it returns a decorator that reuses the given function ID instead of
  deriving one from the function's name.

  Args:
    num_return_vals (int): The number of object IDs that a call to this function
      should return.
    function_id (FunctionID): Only honored in WORKER_MODE; the ID to use for
      the decorated function.

  Returns:
    The wrapped function (bare form) or a decorator (keyword form).
  """
  worker = global_worker

  def make_remote_decorator(num_return_vals, func_id=None):
    def remote_decorator(func):
      func_name = "{}.{}".format(func.__module__, func.__name__)
      if func_id is None:
        # Derive the ID deterministically from the qualified function name.
        function_id = FunctionID((hashlib.sha256(func_name.encode("ascii")).digest())[:20])
      else:
        function_id = func_id

      def func_call(*args, **kwargs):
        """This gets run immediately when a worker calls a remote function."""
        check_connected()
        check_main_thread()
        args = list(args)
        # Fill in the remaining parameters from the keyword arguments, falling
        # back to the signature's defaults (funcsigs._empty when there is none).
        args.extend([kwargs[keyword] if keyword in kwargs else default
                     for keyword, default in keyword_defaults[len(args):]])
        if any([arg is funcsigs._empty for arg in args]):
          raise Exception("Not enough arguments were provided to {}.".format(func_name))
        if _mode() == PYTHON_MODE:
          # In PYTHON_MODE, remote calls simply execute the function. We copy the
          # arguments to prevent the function call from mutating them and to match
          # the usual behavior of immutable remote objects.
          try:
            _env()._running_remote_function_locally = True
            result = func(*copy.deepcopy(args))
          finally:
            _env()._reinitialize()
            _env()._running_remote_function_locally = False
          return result
        objectids = _submit_task(function_id, func_name, args)
        if len(objectids) == 1:
          return objectids[0]
        elif len(objectids) > 1:
          return objectids

      def func_executor(arguments):
        """This gets run when the remote function is executed."""
        return func(*arguments)

      def func_invoker(*args, **kwargs):
        """This is returned by the decorator and used to invoke the function."""
        raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))

      func_invoker.remote = func_call
      func_invoker.executor = func_executor
      func_invoker.is_remote = True
      func_invoker.func_name = func_name
      if sys.version_info >= (3, 0):
        func_invoker.__doc__ = func.__doc__
      else:
        func_invoker.func_doc = func.func_doc

      # keyword_defaults is read lazily by func_call (closure), so it only has
      # to exist by the time the remote function is first called.
      sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.items()]
      keyword_defaults = [(k, v.default) for k, v in sig_params]
      has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
      func_invoker.has_vararg_param = has_vararg_param
      has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
      check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)

      # Everything ready - export the function
      if worker.mode in [None, SCRIPT_MODE, SILENT_MODE]:
        func_name_global_valid = func.__name__ in func.__globals__
        func_name_global_value = func.__globals__.get(func.__name__)
        # Temporarily bind the function's own name to func_invoker so that the
        # function can reference itself as a global variable.
        func.__globals__[func.__name__] = func_invoker
        try:
          # The pickled result is not used here. NOTE(review): this call
          # presumably exists to surface pickling errors at decoration time;
          # confirm before removing it.
          pickling.dumps((func, num_return_vals, func.__module__))
        finally:
          # Undo our changes to the function's globals.
          if func_name_global_valid:
            func.__globals__[func.__name__] = func_name_global_value
          else:
            del func.__globals__[func.__name__]
        if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
          export_remote_function(function_id, func_name, func, num_return_vals)
        elif worker.mode is None:
          # Not connected yet; remember the function so it can be exported later.
          worker.cached_remote_functions.append((function_id, func_name, func, num_return_vals))
      return func_invoker

    return remote_decorator

  if _mode() == WORKER_MODE:
    if "function_id" in kwargs:
      num_return_vals = kwargs["num_return_vals"]
      function_id = kwargs["function_id"]
      return make_remote_decorator(num_return_vals, function_id)
  if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
    # This is the case where the decorator is just @ray.remote.
    num_return_vals = 1
    func = args[0]
    return make_remote_decorator(num_return_vals)(func)
  else:
    # This is the case where the decorator is something like
    # @ray.remote(num_return_vals=2).
    assert len(args) == 0 and "num_return_vals" in kwargs, "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
    num_return_vals = kwargs["num_return_vals"]
    assert "function_id" not in kwargs
    return make_remote_decorator(num_return_vals)
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
  """Check if we support the signature of this function.

  We currently do not allow remote functions to have **kwargs. We also do not
  support keyword arguments in conjunction with a *args argument.

  Args:
    has_kwargs_param (bool): True if the function being checked has a **kwargs
      argument.
    has_vararg_param (bool): True if the function being checked has a *args
      argument.
    keyword_defaults (List): A list of (parameter name, default value) pairs
      for the function being checked. Parameters without a default carry the
      sentinel funcsigs._empty.
    name (str): The name of the function to check.

  Raises:
    Exception: An exception is raised if the signature is not supported.
  """
  # Check if the user specified **kwargs. (Raising a bare string is a
  # TypeError in Python, so wrap the message in an Exception.)
  if has_kwargs_param:
    raise Exception("Function {} has a **kwargs argument, which is currently not supported.".format(name))
  # Check if the user specified a variable number of arguments together with
  # any keyword arguments (parameters that have a default value). Compare
  # against the sentinel by identity: `!=` would invoke the default value's
  # own comparison, which for e.g. numpy arrays does not yield a plain bool.
  if has_vararg_param and any(default is not funcsigs._empty for _, default in keyword_defaults):
    raise Exception("Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name))
def get_arguments_for_execution(function, serialized_args, worker=global_worker):
  """Resolve the arguments of a remote function before executing it.

  Arguments that are ObjectIDs are fetched from the local object store.
  Arguments that were passed by value are used unchanged. This is called by the
  worker that is executing the remote function.

  Args:
    function (Callable): The remote function whose arguments are being
      retrieved.
    serialized_args (List): The arguments to the function. Each is either a
      value passed by value or an ObjectID.

  Returns:
    A list with one resolved value per entry of serialized_args, in order.

  Raises:
    RayGetArgumentError: If a fetched argument is a RayTaskError, meaning the
      task that created that argument failed.
  """
  resolved = []
  for position, raw_arg in enumerate(serialized_args):
    if not isinstance(raw_arg, photon.ObjectID):
      # Passed by value: use it as is.
      resolved.append(raw_arg)
      continue
    # Passed by reference: fetch the value from the local object store.
    value = worker.get_object([raw_arg])[0]
    if isinstance(value, RayTaskError):
      # The producing task failed, so propagate its error message here.
      raise RayGetArgumentError(function.__name__, position, raw_arg, value)
    resolved.append(value)
  return resolved
def store_outputs_in_objstore(objectids, outputs, worker=global_worker):
  """Store the outputs of a remote function in the local object store.

  Each value in ``outputs`` is stored under the object ID at the same position
  in ``objectids``. This is called by the worker that executes the remote
  function. Every output is checked before any is stored. If any output is
  itself an ObjectID, an exception is raised and nothing is stored.

  Note:
    The arguments objectids and outputs should have the same length.

  Args:
    objectids (List[ObjectID]): The object IDs that were assigned to the
      outputs of the remote function call.
    outputs (Tuple): The value returned by the remote function. If the remote
      function was supposed to only return one value, then its output was
      wrapped in a tuple with one element prior to being passed into this
      function.

  Raises:
    Exception: If any output is an ObjectID.
  """
  count = len(objectids)
  bad_index = next((idx for idx in range(count)
                    if isinstance(outputs[idx], photon.ObjectID)), None)
  if bad_index is not None:
    raise Exception("This remote function returned an ObjectID as its {}th return value. This is not allowed.".format(bad_index))
  for idx, object_id in enumerate(objectids):
    worker.put_object(object_id, outputs[idx])