mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 03:02:56 +08:00
e19e2c6284
* User now only needs to copy url to get to notebook * Fixed duplicate code * Added function to print url * Added exception for calling function on worker * Stored webui url in Redis * Fix linting and simplify code. * Now uses 24 bytes hex token * Fixed python 3 compatibility * Fix linting and python 3 compat * Added comment explaining generating the token. * Removed newline * Small fixes. * Fixed jenkins failure * Rebased and changed formatting * Revert "changed formatting" This reverts commit 226510cf0cdcaab9cf42ad30bd9588a963683592.
2327 lines
103 KiB
Python
2327 lines
103 KiB
Python
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import atexit
|
|
import cloudpickle as pickle
|
|
import collections
|
|
import colorama
|
|
import copy
|
|
import hashlib
|
|
import inspect
|
|
import json
|
|
import numpy as np
|
|
import os
|
|
import redis
|
|
import signal
|
|
import sys
|
|
import threading
|
|
import time
|
|
import traceback
|
|
|
|
# Ray modules
|
|
import pyarrow
|
|
import pyarrow.plasma as plasma
|
|
import ray.experimental.state as state
|
|
import ray.serialization as serialization
|
|
import ray.services as services
|
|
import ray.signature as signature
|
|
import ray.local_scheduler
|
|
import ray.plasma
|
|
from ray.utils import FunctionProperties, random_string, binary_to_hex
|
|
|
|
# Modes a Worker can be put into (see Worker.set_mode).
SCRIPT_MODE = 0
WORKER_MODE = 1
PYTHON_MODE = 2
SILENT_MODE = 3

# Values passed as the ``kind`` of a log event.
LOG_POINT = 0
LOG_SPAN_START = 1
LOG_SPAN_END = 2

# Prefix of the Redis keys built by Worker.push_error_to_driver.
ERROR_KEY_PREFIX = b"Error:"
DRIVER_ID_LENGTH = 20
ERROR_ID_LENGTH = 20

# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ID = b"\xff" * 20
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_FUNCTION_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID

# When performing ray.get, wait 1 second before attempting to reconstruct and
# fetch the object again.
GET_TIMEOUT_MILLISECONDS = 1000

# These must be kept in sync with the `error_types` array in
# common/state/error_table.h.
OBJECT_HASH_MISMATCH_ERROR_TYPE = b"object_hash_mismatch"
PUT_RECONSTRUCTION_ERROR_TYPE = b"put_reconstruction"

# This must be kept in sync with the `scheduling_state` enum in common/task.h.
TASK_STATUS_RUNNING = 8
|
|
|
|
|
|
class FunctionID(object):
    """Thin wrapper holding the identifier of a remote function.

    Attributes:
        function_id: The identifier value supplied at construction time.
    """

    def __init__(self, function_id):
        """Remember the identifier that was passed in."""
        self.function_id = function_id

    def id(self):
        """Return the identifier this object was constructed with."""
        return self.function_id
|
|
|
|
|
|
class RayTaskError(Exception):
    """Internal marker object for a task that raised during execution.

    When a task fails, one RayTaskError is placed in the object store for
    each of the task's outputs. Code that later retrieves such an object can
    recognise it and raise, propagating the error message.

    Only one of ``exception`` and ``traceback_str`` is used when formatting:
    the traceback if there is one, the exception otherwise.

    Attributes:
        function_name (str): The name of the function that failed and
            produced the RayTaskError.
        exception (Exception): The exception object, kept only when it is a
            RayGetError or a RayGetArgumentError; None in every other case.
        traceback_str (str): The traceback from the exception, or None.
    """

    def __init__(self, function_name, exception, traceback_str):
        """Initialize a RayTaskError."""
        self.function_name = function_name
        # Only the get-related errors are retained as objects; any other
        # exception is represented solely through its traceback string.
        is_get_failure = isinstance(
            exception, (RayGetError, RayGetArgumentError))
        self.exception = exception if is_get_failure else None
        self.traceback_str = traceback_str

    def __str__(self):
        """Format a RayTaskError as a string."""
        if self.traceback_str is None:
            # Getting the task arguments failed: report the stored exception.
            details = self.exception
        else:
            # The task execution failed: report the traceback.
            details = self.traceback_str
        return ("Remote function {}{}{} failed with:\n\n{}"
                .format(colorama.Fore.RED, self.function_name,
                        colorama.Fore.RESET, details))
|
|
|
|
|
|
class RayGetError(Exception):
    """Raised when get is called on an output of a task that failed.

    Attributes:
        objectid (lib.ObjectID): The ObjectID that get was called on.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, objectid, task_error):
        """Initialize a RayGetError object."""
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetError as a string."""
        failed_task = self.task_error
        template = ("Could not get objectid {}. It was created by remote "
                    "function {}{}{} which failed with:\n\n{}")
        return template.format(self.objectid,
                               colorama.Fore.RED,
                               failed_task.function_name,
                               colorama.Fore.RESET,
                               failed_task)
|
|
|
|
|
|
class RayGetArgumentError(Exception):
    """Raised when one of a task's arguments came from a failed task.

    Attributes:
        argument_index (int): The index (zero indexed) of the failed argument
            in the present task's remote function call.
        function_name (str): The name of the function for the current task.
        objectid (lib.ObjectID): The ObjectID that was passed in as the
            argument.
        task_error (RayTaskError): The RayTaskError object created by the
            failed task.
    """

    def __init__(self, function_name, argument_index, objectid, task_error):
        """Initialize a RayGetArgumentError object."""
        self.function_name = function_name
        self.argument_index = argument_index
        self.objectid = objectid
        self.task_error = task_error

    def __str__(self):
        """Format a RayGetArgumentError as a string."""
        red = colorama.Fore.RED
        reset = colorama.Fore.RESET
        template = ("Failed to get objectid {} as argument {} for remote "
                    "function {}{}{}. It was created by remote function "
                    "{}{}{} which failed with:\n{}")
        return template.format(
            self.objectid, self.argument_index,
            red, self.function_name, reset,
            red, self.task_error.function_name, reset,
            self.task_error)
|
|
|
|
|
|
class Worker(object):
|
|
"""A class used to define the control flow of a worker process.
|
|
|
|
Note:
|
|
The methods in this class are considered unexposed to the user. The
|
|
functions outside of this class are considered exposed.
|
|
|
|
Attributes:
|
|
functions (Dict[str, Callable]): A dictionary mapping the name of a
|
|
remote function to the remote function itself. This is the set of
|
|
remote functions that can be executed by this worker.
|
|
connected (bool): True if Ray has been started and False otherwise.
|
|
mode: The mode of the worker. One of SCRIPT_MODE, PYTHON_MODE,
|
|
SILENT_MODE, and WORKER_MODE.
|
|
cached_remote_functions (List[Tuple[str, str]]): A list of pairs
|
|
representing the remote functions that were defined before the
|
|
worker called connect. The first element is the name of the remote
|
|
function, and the second element is the serialized remote function.
|
|
When the worker eventually does call connect, if it is a driver, it
|
|
will export these functions to the scheduler. If
|
|
cached_remote_functions is None, that means that connect has been
|
|
called already.
|
|
cached_functions_to_run (List): A list of functions to run on all of
|
|
the workers that should be exported as soon as connect is called.
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""Initialize a Worker object."""
|
|
# The functions field is a dictionary that maps a driver ID to a
|
|
# dictionary of functions that have been registered for that driver
|
|
# (this inner dictionary maps function IDs to a tuple of the function
|
|
# name and the function itself). This should only be used on workers
|
|
# that execute remote functions.
|
|
self.functions = collections.defaultdict(lambda: {})
|
|
# The function_properties field is a dictionary that maps a driver ID
|
|
# to a dictionary of functions that have been registered for that
|
|
# driver (this inner dictionary maps function IDs to a tuple of the
|
|
# number of values returned by that function, the number of CPUs
|
|
# required by that function, and the number of GPUs required by that
|
|
# function). This is used when submitting a function (which can be done
|
|
# both on workers and on drivers).
|
|
self.function_properties = collections.defaultdict(lambda: {})
|
|
# This is a dictionary mapping driver ID to a dictionary that maps
|
|
# remote function IDs for that driver to a counter of the number of
|
|
# times that remote function has been executed on this worker. The
|
|
# counter is incremented every time the function is executed on this
|
|
# worker. When the counter reaches the maximum number of executions
|
|
# allowed for a particular function, the worker is killed.
|
|
self.num_task_executions = collections.defaultdict(lambda: {})
|
|
self.connected = False
|
|
self.mode = None
|
|
self.cached_remote_functions = []
|
|
self.cached_functions_to_run = []
|
|
self.fetch_and_register_actor = None
|
|
self.make_actor = None
|
|
self.actors = {}
|
|
# Use a defaultdict for the actor counts. If this is accessed with a
|
|
# missing key, the default value of 0 is returned, and that key value
|
|
# pair is added to the dict.
|
|
self.actor_counters = collections.defaultdict(lambda: 0)
|
|
|
|
def set_mode(self, mode):
|
|
"""Set the mode of the worker.
|
|
|
|
The mode SCRIPT_MODE should be used if this Worker is a driver that is
|
|
being run as a Python script or interactively in a shell. It will print
|
|
information about task failures.
|
|
|
|
The mode WORKER_MODE should be used if this Worker is not a driver. It
|
|
will not print information about tasks.
|
|
|
|
The mode PYTHON_MODE should be used if this Worker is a driver and if
|
|
you want to run the driver in a manner equivalent to serial Python for
|
|
debugging purposes. It will not send remote function calls to the
|
|
scheduler and will insead execute them in a blocking fashion.
|
|
|
|
The mode SILENT_MODE should be used only during testing. It does not
|
|
print any information about errors because some of the tests
|
|
intentionally fail.
|
|
|
|
args:
|
|
mode: One of SCRIPT_MODE, WORKER_MODE, PYTHON_MODE, and
|
|
SILENT_MODE.
|
|
"""
|
|
self.mode = mode
|
|
|
|
def store_and_register(self, object_id, value, depth=100):
|
|
"""Store an object and attempt to register its class if needed.
|
|
|
|
Args:
|
|
object_id: The ID of the object to store.
|
|
value: The value to put in the object store.
|
|
depth: The maximum number of classes to recursively register.
|
|
|
|
Raises:
|
|
Exception: An exception is raised if the attempt to store the
|
|
object fails. This can happen if there is already an object
|
|
with the same ID in the object store or if the object store is
|
|
full.
|
|
"""
|
|
counter = 0
|
|
while True:
|
|
if counter == depth:
|
|
raise Exception("Ray exceeded the maximum number of classes "
|
|
"that it will recursively serialize when "
|
|
"attempting to serialize an object of "
|
|
"type {}.".format(type(value)))
|
|
counter += 1
|
|
try:
|
|
self.plasma_client.put(value, pyarrow.plasma.ObjectID(
|
|
object_id.id()), self.serialization_context)
|
|
break
|
|
except pyarrow.SerializationCallbackError as e:
|
|
try:
|
|
_register_class(type(e.example_object))
|
|
warning_message = ("WARNING: Serializing objects of type "
|
|
"{} by expanding them as dictionaries "
|
|
"of their fields. This behavior may "
|
|
"be incorrect in some cases."
|
|
.format(type(e.example_object)))
|
|
print(warning_message)
|
|
except serialization.RayNotDictionarySerializable:
|
|
_register_class(type(e.example_object), pickle=True)
|
|
warning_message = ("WARNING: Falling back to serializing "
|
|
"objects of type {} by using pickle. "
|
|
"This may be inefficient."
|
|
.format(type(e.example_object)))
|
|
print(warning_message)
|
|
|
|
def put_object(self, object_id, value):
    """Store ``value`` in the local object store under ``object_id``.

    This assumes that the value for object_id has not yet been placed in
    the local object store.

    Args:
        object_id (object_id.ObjectID): The object ID of the value to be
            put.
        value: The value to put in the object store.

    Raises:
        Exception: An exception is raised if ``value`` is itself an
            ObjectID, or if the attempt to store the object fails for a
            reason other than the object already existing (for example
            because the object store is full).
    """
    # Object IDs themselves may not be stored.
    if isinstance(value, ray.local_scheduler.ObjectID):
        refusal = ("Calling 'put' on an ObjectID is not allowed "
                   "(similarly, returning an ObjectID from a remote "
                   "function is not allowed). If you really want to "
                   "do this, you can wrap the ObjectID in a list and "
                   "call 'put' on it (or return it).")
        raise Exception(refusal)

    # Serialize and put the object in the object store.
    try:
        self.store_and_register(object_id, value)
    except pyarrow.PlasmaObjectExists:
        # The object is already present, so it is not added again.
        # TODO(rkn): We need to compare the hashes and make sure that the
        # objects are in fact the same. We also should return an error code
        # to the caller instead of printing a message.
        print("This object already exists in the object store.")
|
|
|
|
def retrieve_and_deserialize(self, object_ids, timeout, error_timeout=10):
|
|
start_time = time.time()
|
|
# Only send the warning once.
|
|
warning_sent = False
|
|
while True:
|
|
try:
|
|
# We divide very large get requests into smaller get requests
|
|
# so that a single get request doesn't block the store for a
|
|
# long time, if the store is blocked, it can block the manager
|
|
# as well as a consequence.
|
|
results = []
|
|
get_request_size = 10000
|
|
for i in range(0, len(object_ids), get_request_size):
|
|
results += self.plasma_client.get(
|
|
object_ids[i:(i + get_request_size)],
|
|
timeout,
|
|
self.serialization_context)
|
|
return results
|
|
except pyarrow.DeserializationCallbackError as e:
|
|
# Wait a little bit for the import thread to import the class.
|
|
# If we currently have the worker lock, we need to release it
|
|
# so that the import thread can acquire it.
|
|
if self.mode == WORKER_MODE:
|
|
self.lock.release()
|
|
time.sleep(0.01)
|
|
if self.mode == WORKER_MODE:
|
|
self.lock.acquire()
|
|
|
|
if time.time() - start_time > error_timeout:
|
|
warning_message = ("This worker or driver is waiting to "
|
|
"receive a class definition so that it "
|
|
"can deserialize an object from the "
|
|
"object store. This may be fine, or it "
|
|
"may be a bug.")
|
|
if not warning_sent:
|
|
self.push_error_to_driver(self.task_driver_id.id(),
|
|
"wait_for_class",
|
|
warning_message)
|
|
warning_sent = True
|
|
|
|
def get_object(self, object_ids):
|
|
"""Get the value or values in the object store associated with the IDs.
|
|
|
|
Return the values from the local object store for object_ids. This will
|
|
block until all the values for object_ids have been written to the
|
|
local object store.
|
|
|
|
Args:
|
|
object_ids (List[object_id.ObjectID]): A list of the object IDs
|
|
whose values should be retrieved.
|
|
"""
|
|
# Make sure that the values are object IDs.
|
|
for object_id in object_ids:
|
|
if not isinstance(object_id, ray.local_scheduler.ObjectID):
|
|
raise Exception("Attempting to call `get` on the value {}, "
|
|
"which is not an ObjectID.".format(object_id))
|
|
# Do an initial fetch for remote objects. We divide the fetch into
|
|
# smaller fetches so as to not block the manager for a prolonged period
|
|
# of time in a single call.
|
|
fetch_request_size = 10000
|
|
plain_object_ids = [plasma.ObjectID(object_id.id())
|
|
for object_id in object_ids]
|
|
for i in range(0, len(object_ids), fetch_request_size):
|
|
self.plasma_client.fetch(
|
|
plain_object_ids[i:(i + fetch_request_size)])
|
|
|
|
# Get the objects. We initially try to get the objects immediately.
|
|
final_results = self.retrieve_and_deserialize(plain_object_ids, 0)
|
|
# Construct a dictionary mapping object IDs that we haven't gotten yet
|
|
# to their original index in the object_ids argument.
|
|
unready_ids = dict((plain_object_ids[i].binary(), i) for (i, val) in
|
|
enumerate(final_results)
|
|
if val is plasma.ObjectNotAvailable)
|
|
was_blocked = (len(unready_ids) > 0)
|
|
# Try reconstructing any objects we haven't gotten yet. Try to get them
|
|
# until at least GET_TIMEOUT_MILLISECONDS milliseconds passes, then
|
|
# repeat.
|
|
while len(unready_ids) > 0:
|
|
for unready_id in unready_ids:
|
|
self.local_scheduler_client.reconstruct_object(unready_id)
|
|
# Do another fetch for objects that aren't available locally yet,
|
|
# in case they were evicted since the last fetch. We divide the
|
|
# fetch into smaller fetches so as to not block the manager for a
|
|
# prolonged period of time in a single call.
|
|
object_ids_to_fetch = list(map(
|
|
plasma.ObjectID, unready_ids.keys()))
|
|
for i in range(0, len(object_ids_to_fetch), fetch_request_size):
|
|
self.plasma_client.fetch(
|
|
object_ids_to_fetch[i:(i + fetch_request_size)])
|
|
results = self.retrieve_and_deserialize(
|
|
object_ids_to_fetch,
|
|
max([GET_TIMEOUT_MILLISECONDS, int(0.01 * len(unready_ids))]))
|
|
# Remove any entries for objects we received during this iteration
|
|
# so we don't retrieve the same object twice.
|
|
for i, val in enumerate(results):
|
|
if val is not plasma.ObjectNotAvailable:
|
|
object_id = object_ids_to_fetch[i].binary()
|
|
index = unready_ids[object_id]
|
|
final_results[index] = val
|
|
unready_ids.pop(object_id)
|
|
|
|
# If there were objects that we weren't able to get locally, let the
|
|
# local scheduler know that we're now unblocked.
|
|
if was_blocked:
|
|
self.local_scheduler_client.notify_unblocked()
|
|
|
|
assert len(final_results) == len(object_ids)
|
|
return final_results
|
|
|
|
def submit_task(self, function_id, args, actor_id=None):
    """Submit a remote task to the scheduler.

    Tell the scheduler to schedule the execution of the function with ID
    function_id with arguments args. Retrieve object IDs for the outputs of
    the function from the scheduler and immediately return them.

    Args:
        function_id: The ID of the function to run; its ``id()`` value is
            used to look up the function's properties for the current
            driver.
        args (List[Any]): The arguments to pass into the function.
            Arguments can be object IDs or they can be values. If they are
            values, they must be serializable objects.
        actor_id: The ID of the actor the task is for. When None, an
            ObjectID built from NIL_ACTOR_ID is used.

    Returns:
        The value of ``task.returns()`` for the submitted task.
    """
    with log_span("ray:submit_task", worker=self):
        check_main_thread()
        if actor_id is None:
            actor_id = ray.local_scheduler.ObjectID(NIL_ACTOR_ID)

        # Object IDs and simple values are handed over as they are; large
        # or complex values are put in the object store first and passed by
        # ID.
        args_for_local_scheduler = []
        for arg in args:
            passed_directly = (
                isinstance(arg, ray.local_scheduler.ObjectID) or
                ray.local_scheduler.check_simple_value(arg))
            args_for_local_scheduler.append(
                arg if passed_directly else put(arg))

        # Look up the various function properties.
        driver_properties = self.function_properties[
            self.task_driver_id.id()]
        function_properties = driver_properties[function_id.id()]
        resources = [function_properties.num_cpus,
                     function_properties.num_gpus,
                     function_properties.num_custom_resource]

        # Submit the task to local scheduler.
        task = ray.local_scheduler.Task(
            self.task_driver_id,
            ray.local_scheduler.ObjectID(function_id.id()),
            args_for_local_scheduler,
            function_properties.num_return_vals,
            self.current_task_id,
            self.task_index,
            actor_id,
            self.actor_counters[actor_id],
            resources)
        # Increment the worker's task index to track how many tasks have
        # been submitted by the current task so far.
        self.task_index += 1
        self.actor_counters[actor_id] += 1
        self.local_scheduler_client.submit(task)

        return task.returns()
|
|
|
|
def run_function_on_all_workers(self, function):
    """Run arbitrary code on all of the workers.

    This function will first be run on the driver, and then it will be
    exported to all of the workers to be run. It will also be run on any
    new workers that register later. If ray.init has not been called yet,
    then cache the function and export it later.

    Args:
        function (Callable): The function to run on all of the workers. On
            the driver it is called with a single dictionary argument that
            has the keys "counter" and "worker". If it returns anything,
            its return values will not be used.
    """
    check_main_thread()

    # If ray.init has not been called yet, only cache the function; it is
    # exported when connect is called.
    if self.mode is None:
        self.cached_functions_to_run.append(function)
        return

    # Attempt to pickle the function before we need it. This could fail,
    # and it is more convenient if the failure happens before we actually
    # run the function locally.
    pickled_function = pickle.dumps(function)

    function_to_run_id = random_string()
    key = b"FunctionsToRun:" + function_to_run_id

    # First run the function on the driver. Pass in the number of workers
    # on this node that have already started executing this remote
    # function, and increment that value. Subtract 1 so that the counter
    # starts at 0.
    executions_so_far = self.redis_client.hincrby(
        self.node_ip_address, key, 1)
    counter = executions_so_far - 1
    function({"counter": counter, "worker": self})

    # Run the function on all workers.
    export = {"driver_id": self.task_driver_id.id(),
              "function_id": function_to_run_id,
              "function": pickled_function}
    self.redis_client.hmset(key, export)
    self.redis_client.rpush("Exports", key)
|
|
|
|
def push_error_to_driver(self, driver_id, error_type, message, data=None):
    """Push an error message to the driver to be printed in the background.

    The error is written to Redis as a hash under a fresh key made of
    ERROR_KEY_PREFIX, the driver ID and a random string, and that key is
    then appended to the "ErrorKeys" list.

    Args:
        driver_id: The ID of the driver to push the error message to.
        error_type (str): The type of the error.
        message (str): The message that will be printed in the background
            on the driver.
        data: This should be a dictionary mapping strings to strings;
            defaults to an empty dictionary. It is handed to the Redis
            client as it is; this method does not encode it itself.
    """
    if data is None:
        data = {}
    error_key = ERROR_KEY_PREFIX + driver_id + b":" + random_string()
    fields = {"type": error_type,
              "message": message,
              "data": data}
    self.redis_client.hmset(error_key, fields)
    self.redis_client.rpush("ErrorKeys", error_key)
|
|
|
|
def _wait_for_actor(self):
    """Block until this worker's actor ID is present in ``self.actors``.

    Must only be called when this worker has an actor ID other than
    NIL_ACTOR_ID. Polls once per millisecond.
    """
    assert self.actor_id != NIL_ACTOR_ID
    # Wait until the actor has been imported.
    while True:
        if self.actor_id in self.actors:
            return
        time.sleep(0.001)
|
|
|
|
def _wait_for_function(self, function_id, driver_id, timeout=10):
    """Wait until the function to be executed is present on this worker.

    This method will simply loop until the import thread has imported the
    relevant function. If we spend too long in this loop, that may indicate
    a problem somewhere and we will push an error message to the user.

    If this worker is an actor (its actor ID is not NIL_ACTOR_ID), then
    this will instead wait until the actor has been defined.

    Args:
        function_id: The ID of the function that we want to execute; its
            ``id()`` value is looked up among the functions registered for
            ``driver_id``.
        driver_id (str): The ID of the driver the function belongs to. The
            error message is pushed to this driver if this times out.
        timeout: Number of seconds to wait before the warning is pushed.
            The warning is sent once; waiting continues afterwards.
    """
    not_registered_warning = ("This worker was asked to execute a "
                              "function that it does not have "
                              "registered. You may have to restart "
                              "Ray.")
    began = time.time()
    # Only send the warning once.
    already_warned = False

    while True:
        with self.lock:
            if self.actor_id == NIL_ACTOR_ID:
                ready = function_id.id() in self.functions[driver_id]
            else:
                ready = self.actor_id in self.actors
            if ready:
                return

            if not already_warned and time.time() - began > timeout:
                self.push_error_to_driver(driver_id,
                                          "wait_for_function",
                                          not_registered_warning)
                already_warned = True
        # Sleep without holding the lock.
        time.sleep(0.001)
|
|
|
|
def _get_arguments_for_execution(self, function_name, serialized_args):
|
|
"""Retrieve the arguments for the remote function.
|
|
|
|
This retrieves the values for the arguments to the remote function that
|
|
were passed in as object IDs. Argumens that were passed by value are
|
|
not changed. This is called by the worker that is executing the remote
|
|
function.
|
|
|
|
Args:
|
|
function_name (str): The name of the remote function whose
|
|
arguments are being retrieved.
|
|
serialized_args (List): The arguments to the function. These are
|
|
either strings representing serialized objects passed by value
|
|
or they are ObjectIDs.
|
|
|
|
Returns:
|
|
The retrieved arguments in addition to the arguments that were
|
|
passed by value.
|
|
|
|
Raises:
|
|
RayGetArgumentError: This exception is raised if a task that
|
|
created one of the arguments failed.
|
|
"""
|
|
arguments = []
|
|
for (i, arg) in enumerate(serialized_args):
|
|
if isinstance(arg, ray.local_scheduler.ObjectID):
|
|
# get the object from the local object store
|
|
argument = self.get_object([arg])[0]
|
|
if isinstance(argument, RayTaskError):
|
|
# If the result is a RayTaskError, then the task that
|
|
# created this object failed, and we should propagate the
|
|
# error message here.
|
|
raise RayGetArgumentError(function_name, i, arg, argument)
|
|
else:
|
|
# pass the argument by value
|
|
argument = arg
|
|
|
|
arguments.append(argument)
|
|
return arguments
|
|
|
|
def _store_outputs_in_objstore(self, objectids, outputs):
|
|
"""Store the outputs of a remote function in the local object store.
|
|
|
|
This stores the values that were returned by a remote function in the
|
|
local object store. If any of the return values are object IDs, then
|
|
these object IDs are aliased with the object IDs that the scheduler
|
|
assigned for the return values. This is called by the worker that
|
|
executes the remote function.
|
|
|
|
Note:
|
|
The arguments objectids and outputs should have the same length.
|
|
|
|
Args:
|
|
objectids (List[ObjectID]): The object IDs that were assigned to
|
|
the outputs of the remote function call.
|
|
outputs (Tuple): The value returned by the remote function. If the
|
|
remote function was supposed to only return one value, then its
|
|
output was wrapped in a tuple with one element prior to being
|
|
passed into this function.
|
|
"""
|
|
for i in range(len(objectids)):
|
|
self.put_object(objectids[i], outputs[i])
|
|
|
|
def _process_task(self, task):
    """Execute a task assigned to this worker.

    This method deserializes a task from the scheduler, and attempts to
    execute the task. If the task succeeds, the outputs are stored in the
    local object store. If the task throws an exception, RayTaskError
    objects are stored in the object store to represent the failed task
    (these will be retrieved by calls to get or by subsequent tasks that
    use the outputs of this task), and the error is pushed to the driver.

    Args:
        task: The task to execute. Its driver ID, task ID, function ID,
            arguments, return object IDs and actor ID are read from it.
    """
    # Phases of processing. The error handler formats the failure according
    # to the phase that was in progress when the exception was raised.
    getting_arguments, executing, storing_outputs = range(3)
    phase = getting_arguments
    try:
        # The ID of the driver that this task belongs to. This is needed so
        # that if the task throws an exception, we propagate the error
        # message to the correct driver.
        self.task_driver_id = task.driver_id()
        self.current_task_id = task.task_id()
        self.current_function_id = task.function_id().id()
        self.task_index = 0
        self.put_index = 0
        function_id = task.function_id()
        args = task.arguments()
        return_object_ids = task.returns()
        function_name, function_executor = (
            self.functions[self.task_driver_id.id()][function_id.id()])

        # Get task arguments from the object store.
        with log_span("ray:task:get_arguments", worker=self):
            arguments = self._get_arguments_for_execution(function_name,
                                                          args)
            phase = executing

        # Execute the task.
        with log_span("ray:task:execute", worker=self):
            if task.actor_id().id() == NIL_ACTOR_ID:
                outputs = function_executor.executor(arguments)
            else:
                outputs = function_executor(
                    self.actors[task.actor_id().id()], *arguments)
            phase = storing_outputs

        # Store the outputs in the local object store.
        with log_span("ray:task:store_outputs", worker=self):
            if len(return_object_ids) == 1:
                outputs = (outputs,)
            self._store_outputs_in_objstore(return_object_ids, outputs)
    except Exception as e:
        # We determine whether the exception was caused by the call to
        # _get_arguments_for_execution or by the execution of the remote
        # function or by the call to _store_outputs_in_objstore. Depending
        # on which case occurred, we format the error message differently.
        # The case is identified by the phase recorded above.
        # NOTE(review): if the exception is raised before function_name or
        # return_object_ids is bound (e.g. the function lookup fails), the
        # code below raises a NameError of its own -- confirm whether that
        # can happen in practice.
        if phase == executing:
            if task.actor_id().id() == NIL_ACTOR_ID:
                # The error occurred during the task execution.
                traceback_str = format_error_message(
                    traceback.format_exc(), task_exception=True)
            else:
                # The error occurred during the execution of an actor task.
                traceback_str = format_error_message(
                    traceback.format_exc())
        elif phase == storing_outputs:
            # The error occurred after the task executed.
            traceback_str = format_error_message(traceback.format_exc())
        elif isinstance(e, (RayGetError, RayGetArgumentError)):
            # The error occurred before the task execution: getting the
            # task arguments failed.
            traceback_str = None
        else:
            # Some other error occurred before the task execution.
            traceback_str = traceback.format_exc()

        # One failure marker is stored per return object ID.
        failure_object = RayTaskError(function_name, e, traceback_str)
        failure_objects = [failure_object for _
                           in range(len(return_object_ids))]
        self._store_outputs_in_objstore(return_object_ids, failure_objects)
        # Log the error message.
        self.push_error_to_driver(self.task_driver_id.id(), "task",
                                  str(failure_object),
                                  data={"function_id": function_id.id(),
                                        "function_name": function_name})
|
|
|
|
def _checkpoint_actor_state(self, actor_counter):
    """Checkpoint the actor state.

    This currently saves the checkpoint to Redis, but the checkpoint really
    needs to go somewhere else. Checkpoints stored under a smaller counter
    are deleted afterwards.

    Args:
        actor_counter: The index of the most recent task that ran on this
            actor.
    """
    print("Saving actor checkpoint. actor_counter = {}."
          .format(actor_counter))

    actor_key = b"Actor:" + self.actor_id
    checkpoint = self.actors[self.actor_id].__ray_save_checkpoint__()

    # Save the checkpoint in Redis. TODO(rkn): Checkpoints should not be
    # stored in Redis. Fix this.
    field_name = "checkpoint_{}".format(actor_counter)
    self.redis_client.hset(actor_key, field_name, checkpoint)

    # Collect the indices of all stored checkpoints first, then remove the
    # ones that precede the checkpoint that was just written.
    prefix = b"checkpoint_"
    stored_indices = [int(field[len(prefix):])
                      for field in self.redis_client.hkeys(actor_key)
                      if field.startswith(prefix)]
    for stored_index in stored_indices:
        if stored_index < actor_counter:
            self.redis_client.hdel(
                actor_key, "checkpoint_{}".format(stored_index))
|
|
|
|
def _wait_for_and_process_task(self, task):
    """Wait until a task can run, run it, and do the follow-up bookkeeping.

    After the task has run, the log events are flushed, the execution
    counter of the function is bumped (the process exits once the function
    has reached its max_calls), and the actor state is checkpointed when the
    checkpoint interval is hit.

    Args:
        task: The task to execute.
    """
    function_id = task.function_id()
    driver_key = task.driver_id().id()
    function_key = function_id.id()

    # Block until the function has been registered on this worker. Warnings
    # are pushed to the user if this takes too long.
    with log_span("ray:wait_for_function", worker=self):
        self._wait_for_function(function_id, driver_key)

    # TODO(rkn): Consider acquiring this lock with a timeout and pushing a
    # warning to the user if we are waiting too long to acquire the lock
    # because that may indicate that the system is hanging, and it'd be
    # good to know where the system is hanging.
    log(event_type="ray:acquire_lock", kind=LOG_SPAN_START, worker=self)
    with self.lock:
        log(event_type="ray:acquire_lock", kind=LOG_SPAN_END,
            worker=self)

        function_name, _ = self.functions[driver_key][function_key]
        contents = {"function_name": function_name,
                    "task_id": task.task_id().hex(),
                    "worker_id": binary_to_hex(self.worker_id)}
        with log_span("ray:task", contents=contents, worker=self):
            self._process_task(task)

    # Push all of the log events to the global state store.
    flush_log()

    # Count this execution and stop the process once the function has been
    # run max_calls times.
    executions = self.num_task_executions[driver_key]
    executions[function_key] += 1
    max_calls = self.function_properties[driver_key][function_key].max_calls
    if executions[function_key] == max_calls:
        ray.worker.global_worker.local_scheduler_client.disconnect()
        os._exit(0)

    # Checkpoint the actor state if it is the right time to do so.
    actor_counter = task.actor_counter()
    is_actor = self.actor_id != NIL_ACTOR_ID
    if (is_actor and self.actor_checkpoint_interval != -1 and
            actor_counter % self.actor_checkpoint_interval == 0):
        self._checkpoint_actor_state(actor_counter)
|
|
|
|
def _get_next_task_from_local_scheduler(self):
    """Ask the local scheduler for the next task.

    The request is recorded as a "ray:get_task" log span.

    Returns:
        A task from the local scheduler.
    """
    with log_span("ray:get_task", worker=self):
        next_task = self.local_scheduler_client.get_task()
    return next_task
|
|
|
|
def main_loop(self):
    """Receive and execute tasks forever; this is the worker's main loop."""

    def handle_sigterm(signum, frame):
        # Clean up this worker before leaving on SIGTERM.
        cleanup(worker=self)
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)

    check_main_thread()
    while True:
        self._wait_for_and_process_task(
            self._get_next_task_from_local_scheduler())
|
|
|
|
|
|
def get_gpu_ids():
    """Return the IDs of the GPUs that are available to the worker.

    The IDs are obtained from the local scheduler client of the global
    worker. Each ID is an integer in the range [0, NUM_GPUS - 1], where
    NUM_GPUS is the number of GPUs that the node has.
    """
    scheduler_client = global_worker.local_scheduler_client
    return scheduler_client.gpu_ids()
|
|
|
|
|
|
def _webui_url_helper(client):
    """Look up the URL of the web UI in Redis.

    Args:
        client: A redis client to use to query the primary Redis shard.

    Returns:
        The URL of the web UI as a string, or None if no URL is stored under
            the "url" field of the "webui" hash.
    """
    raw_url = client.hmget("webui", "url")[0]
    if raw_url is None:
        return None
    return raw_url.decode("ascii")
|
|
|
|
|
|
def get_webui_url():
    """Return the URL under which the web UI can be reached.

    The URL is read through the Redis client of the global worker. Note that
    the URL does not specify which node the web UI is on.

    Returns:
        The URL of the web UI as a string.
    """
    redis_client = global_worker.redis_client
    return _webui_url_helper(redis_client)
|
|
|
|
|
|
global_worker = Worker()
|
|
"""Worker: The global Worker object for this worker process.
|
|
|
|
We use a global Worker object to ensure that there is a single worker object
|
|
per worker process.
|
|
"""
|
|
|
|
global_state = state.GlobalState()
|
|
|
|
|
|
class RayConnectionError(Exception):
    """Raised when a command is used on a worker that is not connected."""
|
|
|
|
|
|
def check_main_thread():
    """Check that we are currently on the main thread.

    The check compares the name of the current thread with "MainThread".

    Raises:
        Exception: An exception is raised if this is called on a thread other
            than the main thread.
    """
    # Thread.getName() is deprecated as of Python 3.10; the name attribute is
    # the supported spelling and also exists on Python 2.
    thread_name = threading.current_thread().name
    if thread_name != "MainThread":
        raise Exception("The Ray methods are not thread safe and must be "
                        "called from the main thread. This method was called "
                        "from thread {}."
                        .format(thread_name))
|
|
|
|
|
|
def check_connected(worker=global_worker):
    """Make sure that the worker is connected.

    Args:
        worker: The worker to check. Defaults to the global worker.

    Raises:
        RayConnectionError: Raised if the worker is not connected.
    """
    if worker.connected:
        return
    raise RayConnectionError("This command cannot be called before Ray "
                             "has been started. You can start Ray with "
                             "'ray.init()'.")
|
|
|
|
|
|
def print_failed_task(task_status):
    """Print a short report about a failed task.

    Args:
        task_status (Dict): A dictionary containing the name
            ("function_name"), operation id ("operationid"), and error
            message ("error_message") for a failed task.
    """
    report_template = """
      Error: Task failed
        Function Name: {}
        Task ID: {}
        Error Message: \n{}
    """
    function_name = task_status["function_name"]
    operation_id = task_status["operationid"]
    error_message = task_status["error_message"]
    print(report_template.format(function_name, operation_id, error_message))
|
|
|
|
|
|
def error_applies_to_driver(error_key, worker=global_worker):
    """Return True if the error is for this driver and False otherwise.

    Args:
        error_key: The Redis key of the error. The driver ID is read from the
            DRIVER_ID_LENGTH bytes that follow ERROR_KEY_PREFIX.
        worker: The worker whose task_driver_id is compared with the driver
            ID in the key.
    """
    # TODO(rkn): Should probably check that this is only called on a driver.
    # Check that the error key is formatted as in push_error_to_driver.
    prefix_length = len(ERROR_KEY_PREFIX)
    expected_length = prefix_length + DRIVER_ID_LENGTH + 1 + ERROR_ID_LENGTH
    assert len(error_key) == expected_length, error_key
    driver_id = error_key[prefix_length:prefix_length + DRIVER_ID_LENGTH]
    if driver_id == worker.task_driver_id.id():
        return True
    # A driver ID consisting only of zero bytes addresses all drivers.
    return driver_id == DRIVER_ID_LENGTH * b"\x00"
|
|
|
|
|
|
def error_info(worker=global_worker):
    """Return information about failed tasks.

    Args:
        worker: The connected worker whose Redis client is queried.

    Returns:
        A list with the contents of the Redis hash of every error in
            "ErrorKeys" that applies to this driver. For object hash
            mismatch and put reconstruction errors the b"data" entry is
            replaced by the name of the function involved.
    """
    check_connected(worker)
    check_main_thread()
    redis_client = worker.redis_client
    needs_function_name = [OBJECT_HASH_MISMATCH_ERROR_TYPE,
                           PUT_RECONSTRUCTION_ERROR_TYPE]
    errors = []
    for error_key in redis_client.lrange("ErrorKeys", 0, -1):
        if not error_applies_to_driver(error_key, worker=worker):
            continue
        error_contents = redis_client.hgetall(error_key)
        # TODO(rkn): Change this so that we don't have to look up additional
        # information. Ideally all relevant information would already be in
        # error_contents.
        if error_contents[b"type"] in needs_function_name:
            # For these errors the data field holds a function ID, which is
            # swapped for the function name here.
            function_id = error_contents[b"data"]
            if function_id == NIL_FUNCTION_ID:
                function_name = b"Driver"
            else:
                function_key = (b"RemoteFunction:" +
                                worker.task_driver_id.id() +
                                b":" + function_id)
                function_name = redis_client.hget(function_key, "name")
            error_contents[b"data"] = function_name
        errors.append(error_contents)

    return errors
|
|
|
|
|
|
def _initialize_serialization(worker=global_worker):
    """Set up the serialization context of the worker.

    Custom serializers are registered for object IDs and for numpy arrays.
    On a driver (SCRIPT_MODE or SILENT_MODE), the Ray exception classes are
    registered as well, and lambdas, sets, and types are registered to be
    serialized with pickle.

    Args:
        worker: The worker that receives the new serialization context.
    """
    context = pyarrow.SerializationContext()
    worker.serialization_context = context

    # Object IDs are serialized as their binary ID.
    def serialize_object_id(object_id):
        return object_id.id()

    def deserialize_object_id(binary_id):
        return ray.local_scheduler.ObjectID(binary_id)

    worker.serialization_context.register_type(
        ray.local_scheduler.ObjectID, 20 * b"\x00", pickle=False,
        custom_serializer=serialize_object_id,
        custom_deserializer=deserialize_object_id)

    # Numpy arrays that contain objects are serialized as a nested list
    # together with the string form of their dtype.
    def serialize_array(array):
        return array.tolist(), array.dtype.str

    def deserialize_array(serialized):
        values, dtype_str = serialized[0], serialized[1]
        return np.array(values, dtype=np.dtype(dtype_str))

    worker.serialization_context.register_type(
        np.ndarray, 20 * b"\x01", pickle=False,
        custom_serializer=serialize_array,
        custom_deserializer=deserialize_array)

    if worker.mode not in [SCRIPT_MODE, SILENT_MODE]:
        return

    # These should only be called on the driver because _register_class
    # will export the class to all of the workers.
    for error_class in (RayTaskError, RayGetError, RayGetArgumentError):
        _register_class(error_class)
    # Tell Ray to serialize lambdas, sets, and types with pickle.
    for pickled_class in (type(lambda: 0), type(set()), type(int)):
        _register_class(pickled_class, pickle=True)
|
|
|
|
|
|
def get_address_info_from_redis_helper(redis_address, node_ip_address):
    """Read the addresses of the processes on this node from Redis.

    Args:
        redis_address: The address of the Redis server as "ip:port".
        node_ip_address: The IP address of the node whose plasma managers and
            local schedulers are looked up.

    Returns:
        A dictionary with the node IP address, the Redis address, the object
            store addresses, the local scheduler socket names, and the URL
            of the web UI.
    """
    redis_ip_address, redis_port = redis_address.split(":")
    # For this command to work, some other client (on the same machine as
    # Redis) must have run "CONFIG SET protected-mode no".
    redis_client = redis.StrictRedis(host=redis_ip_address,
                                     port=int(redis_port))
    # The client table prefix must be kept in sync with the file
    # "src/common/redis_module/ray_redis_module.cc" where it is defined.
    REDIS_CLIENT_TABLE_PREFIX = "CL:"
    client_keys = redis_client.keys(REDIS_CLIENT_TABLE_PREFIX + "*")

    # Keep the live clients on this node, grouped by their type.
    plasma_managers = []
    local_schedulers = []
    for client_key in client_keys:
        info = redis_client.hgetall(client_key)

        # Ignore clients that were deleted.
        if bool(int(info[b"deleted"])):
            continue

        for required_field in (b"ray_client_id", b"node_ip_address",
                               b"client_type"):
            assert required_field in info
        if info[b"node_ip_address"].decode("ascii") != node_ip_address:
            continue
        client_type = info[b"client_type"].decode("ascii")
        if client_type == "plasma_manager":
            plasma_managers.append(info)
        elif client_type == "local_scheduler":
            local_schedulers.append(info)

    # Make sure that we got at least one plasma manager and local scheduler.
    assert len(plasma_managers) >= 1
    assert len(local_schedulers) >= 1

    # Build the address information.
    object_store_addresses = []
    for manager in plasma_managers:
        manager_port = services.get_port(manager[b"address"].decode("ascii"))
        store_address = services.ObjectStoreAddress(
            name=manager[b"store_socket_name"].decode("ascii"),
            manager_name=manager[b"manager_socket_name"].decode("ascii"),
            manager_port=manager_port)
        object_store_addresses.append(store_address)
    scheduler_names = []
    for scheduler in local_schedulers:
        scheduler_names.append(
            scheduler[b"local_scheduler_socket_name"].decode("ascii"))
    return {"node_ip_address": node_ip_address,
            "redis_address": redis_address,
            "object_store_addresses": object_store_addresses,
            "local_scheduler_socket_names": scheduler_names,
            # Web UI should be running.
            "webui_url": _webui_url_helper(redis_client)}
|
|
|
|
|
|
def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5):
    """Read the address info from Redis, retrying while it is incomplete.

    Args:
        redis_address: The address of the Redis server as "ip:port".
        node_ip_address: The IP address of the node to look up.
        num_retries: The number of times the lookup is retried, with a one
            second pause in between, before the last exception is re-raised.

    Returns:
        The dictionary produced by get_address_info_from_redis_helper.
    """
    failed_attempts = 0
    while True:
        try:
            return get_address_info_from_redis_helper(redis_address,
                                                      node_ip_address)
        except Exception:
            if failed_attempts == num_retries:
                raise
            # Some of the information may not be in Redis yet, so wait a
            # little bit.
            print("Some processes that the driver needs to connect to have "
                  "not registered with Redis, so retrying. Have you run "
                  "'ray start' on this node?")
            time.sleep(1)
            failed_attempts += 1
|
|
|
|
|
|
def _init(address_info=None,
          start_ray_local=False,
          object_id_seed=None,
          num_workers=None,
          num_local_schedulers=None,
          object_store_memory=None,
          driver_mode=SCRIPT_MODE,
          redirect_output=False,
          start_workers_from_local_scheduler=True,
          num_cpus=None,
          num_gpus=None,
          num_custom_resource=None,
          num_redis_shards=None):
    """Attach this driver to a Ray cluster, starting the cluster if asked to.

    Two cases are handled. Either a Ray cluster already exists and this
    driver just connects to it, or the processes that make up a Ray cluster
    are started first and the driver connects to the new cluster. In
    PYTHON_MODE no processes are started and nothing is looked up.

    Args:
        address_info (dict): A dictionary with address information for
            processes in a partially-started Ray cluster. If
            start_ray_local=True, any processes not in this dictionary will
            be started. If provided, an updated address_info dictionary will
            be returned to include processes that are newly started.
        start_ray_local (bool): If True then this will start any processes
            not already in address_info, including Redis, a global
            scheduler, local scheduler(s), object store(s), and worker(s).
            It will also kill these processes when Python exits. If False,
            this will attach to an existing Ray cluster.
        object_id_seed (int): Used to seed the deterministic generation of
            object IDs. The same value can be used across multiple runs of
            the same job in order to generate the object IDs in a consistent
            manner. However, the same ID should not be used for different
            jobs.
        num_workers (int): The number of workers to start. This is only
            provided if start_ray_local is True.
        num_local_schedulers (int): The number of local schedulers to start.
            This is only provided if start_ray_local is True.
        object_store_memory: The amount of memory (in bytes) to start the
            object store with.
        driver_mode: The mode in which to start the driver. This should be
            one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
        redirect_output (bool): True if stdout and stderr for all the
            processes should be redirected to files and false otherwise.
        start_workers_from_local_scheduler (bool): If this flag is True, then
            start the initial workers from the local scheduler. Else, start
            them from Python. The latter case is for debugging purposes
            only.
        num_cpus: A list containing the number of CPUs the local schedulers
            should be configured with.
        num_gpus: A list containing the number of GPUs the local schedulers
            should be configured with.
        num_custom_resource: A list containing the quantity of a user-defined
            custom resource that the local schedulers should be configured
            with.
        num_redis_shards: The number of Redis shards to start in addition to
            the primary Redis shard.

    Returns:
        Address information about the started processes.

    Raises:
        Exception: An exception is raised if an inappropriate combination of
            arguments is passed in.
    """
    check_main_thread()
    valid_modes = [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]
    if driver_mode not in valid_modes:
        raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, "
                        "ray.PYTHON_MODE, ray.SILENT_MODE].")

    # Get addresses of existing services.
    if address_info is not None:
        assert isinstance(address_info, dict)
    else:
        address_info = {}
    node_ip_address = address_info.get("node_ip_address")
    redis_address = address_info.get("redis_address")

    # Start any services that do not yet exist. If starting Ray in
    # PYTHON_MODE, don't start any other processes.
    if driver_mode != PYTHON_MODE and start_ray_local:
        # In this case, we launch a scheduler, a new object store, and some
        # workers, and we connect to them. We do not launch any processes
        # that are already registered in address_info.
        # Use the address 127.0.0.1 in local mode.
        if node_ip_address is None:
            node_ip_address = "127.0.0.1"
        # Use 1 local scheduler if num_local_schedulers is not provided. If
        # existing local schedulers are provided, use that count as
        # num_local_schedulers.
        local_schedulers = address_info.get("local_scheduler_socket_names", [])
        if num_local_schedulers is None:
            num_local_schedulers = len(local_schedulers) or 1
        # Use 1 additional redis shard if num_redis_shards is not provided.
        if num_redis_shards is None:
            num_redis_shards = 1
        # Start the scheduler, object store, and some workers. These will be
        # killed by the call to cleanup(), which happens when the Python
        # script exits.
        address_info = services.start_ray_head(
            address_info=address_info,
            node_ip_address=node_ip_address,
            num_workers=num_workers,
            num_local_schedulers=num_local_schedulers,
            object_store_memory=object_store_memory,
            redirect_output=redirect_output,
            start_workers_from_local_scheduler=(
                start_workers_from_local_scheduler),
            num_cpus=num_cpus,
            num_gpus=num_gpus,
            num_custom_resource=num_custom_resource,
            num_redis_shards=num_redis_shards)
    elif driver_mode != PYTHON_MODE:
        # Attaching to an existing cluster: only the Redis address may be
        # given. The checks run in order and the first violation is raised.
        existing = "When connecting to an existing cluster, "
        resources_given = (num_cpus is not None or num_gpus is not None or
                           num_custom_resource is not None)
        argument_checks = [
            (redis_address is None,
             existing + "redis_address must be provided."),
            (num_workers is not None,
             existing + "num_workers must not be provided."),
            (num_local_schedulers is not None,
             existing + "num_local_schedulers must not be provided."),
            (resources_given,
             existing + "resource labels (e.g., num_gpus, num_cpus, "
                        "num_custom_resource) must not be provided."),
            (num_redis_shards is not None,
             existing + "num_redis_shards must not be provided."),
            (object_store_memory is not None,
             existing + "object_store_memory must not be provided."),
        ]
        for violated, message in argument_checks:
            if violated:
                raise Exception(message)
        # Get the node IP address if one is not provided.
        if node_ip_address is None:
            node_ip_address = services.get_node_ip_address(redis_address)
        # Get the address info of the processes to connect to from Redis.
        address_info = get_address_info_from_redis(redis_address,
                                                   node_ip_address)

    # Connect this driver to Redis, the object store, and the local scheduler.
    # Choose the first object store and local scheduler if there are multiple.
    # The corresponding call to disconnect will happen in the call to cleanup()
    # when the Python script exits.
    driver_address_info = {}
    if driver_mode != PYTHON_MODE:
        driver_address_info = {
            "node_ip_address": node_ip_address,
            "redis_address": address_info["redis_address"],
            "store_socket_name": (
                address_info["object_store_addresses"][0].name),
            "manager_socket_name": (
                address_info["object_store_addresses"][0].manager_name),
            "local_scheduler_socket_name": (
                address_info["local_scheduler_socket_names"][0]),
            "webui_url": address_info["webui_url"]}
    connect(driver_address_info, object_id_seed=object_id_seed,
            mode=driver_mode, worker=global_worker, actor_id=NIL_ACTOR_ID)
    return address_info
|
|
|
|
|
|
def init(redis_address=None, node_ip_address=None, object_id_seed=None,
         num_workers=None, driver_mode=SCRIPT_MODE, redirect_output=False,
         num_cpus=None, num_gpus=None, num_custom_resource=None,
         num_redis_shards=None):
    """Connect to an existing Ray cluster or start one and connect to it.

    This method handles two cases. Either a Ray cluster already exists and we
    just attach this driver to it, or we start all of the processes associated
    with a Ray cluster and attach to the newly started cluster.

    Args:
        redis_address (str): The address of the Redis server to connect to. If
            this address is not provided, then this command will start Redis,
            a global scheduler, a local scheduler, a plasma store, a plasma
            manager, and some workers. It will also kill these processes when
            Python exits.
        node_ip_address (str): The IP address of the node that we are on.
        object_id_seed (int): Used to seed the deterministic generation of
            object IDs. The same value can be used across multiple runs of the
            same job in order to generate the object IDs in a consistent
            manner. However, the same ID should not be used for different jobs.
        num_workers (int): The number of workers to start. This is only
            provided if redis_address is not provided.
        driver_mode: The mode in which to start the driver. This should be
            one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE.
        redirect_output (bool): True if stdout and stderr for all the processes
            should be redirected to files and false otherwise.
        num_cpus (int): Number of cpus the user wishes all local schedulers to
            be configured with.
        num_gpus (int): Number of gpus the user wishes all local schedulers to
            be configured with.
        num_custom_resource (int): The quantity of a user-defined custom
            resource that the local scheduler should be configured with. This
            flag is experimental and is subject to changes in the future.
        num_redis_shards: The number of Redis shards to start in addition to
            the primary Redis shard.

    Returns:
        Address information about the started processes.

    Raises:
        Exception: An exception is raised if an inappropriate combination of
            arguments is passed in.
    """
    info = {"node_ip_address": node_ip_address,
            "redis_address": redis_address}
    # A cluster is started locally exactly when no Redis address is given.
    # object_id_seed has to be forwarded explicitly; otherwise _init falls
    # back to its default of None and the caller's seed is silently ignored.
    return _init(address_info=info, start_ray_local=(redis_address is None),
                 object_id_seed=object_id_seed,
                 num_workers=num_workers, driver_mode=driver_mode,
                 redirect_output=redirect_output, num_cpus=num_cpus,
                 num_gpus=num_gpus, num_custom_resource=num_custom_resource,
                 num_redis_shards=num_redis_shards)
|
|
|
|
|
|
def cleanup(worker=global_worker):
    """Disconnect the worker, and terminate any processes started in init.

    This will automatically run at the end when a Python process that uses Ray
    exits. It is ok to run this twice in a row. Note that we manually call
    services.cleanup() in the tests because we need to start and stop many
    clusters in the tests, but the import and exit only happen once.

    Args:
        worker: The worker to disconnect and clean up.
    """
    disconnect(worker)
    if hasattr(worker, "local_scheduler_client"):
        del worker.local_scheduler_client
    if hasattr(worker, "plasma_client"):
        worker.plasma_client.disconnect()

    is_driver = worker.mode in [SCRIPT_MODE, SILENT_MODE]
    if is_driver:
        # Push the finish time to Redis and clean up any other services that
        # were started with the driver.
        driver_key = b"Drivers:" + worker.worker_id
        worker.redis_client.hmset(driver_key, {"end_time": time.time()})
        services.cleanup()
    else:
        # If this is not a driver, make sure there are no orphan processes,
        # besides possibly the worker itself.
        for process_type, processes in services.all_processes.items():
            if process_type == services.PROCESS_TYPE_WORKER:
                allowed = 1
            else:
                allowed = 0
            assert len(processes) <= allowed

    worker.set_mode(None)
|
|
|
|
|
|
atexit.register(cleanup)
|
|
|
|
# Define a custom excepthook so that if the driver exits with an exception, we
|
|
# can push that exception to Redis.
|
|
normal_excepthook = sys.excepthook
|
|
|
|
|
|
def custom_excepthook(type, value, tb):
    """Record an uncaught driver exception in Redis, then report it as usual.

    Args:
        type: The class of the uncaught exception.
        value: The exception instance.
        tb: The traceback of the exception.
    """
    is_driver = global_worker.mode in [SCRIPT_MODE, SILENT_MODE]
    if is_driver:
        # Store the formatted traceback in the Redis hash of this driver.
        formatted_traceback = "".join(traceback.format_tb(tb))
        driver_key = b"Drivers:" + global_worker.worker_id
        global_worker.redis_client.hmset(
            driver_key, {"exception": formatted_traceback})
    # Hand over to the excepthook that was installed before this one.
    normal_excepthook(type, value, tb)
|
|
|
|
|
|
sys.excepthook = custom_excepthook
|
|
|
|
|
|
def print_error_messages(worker):
    """Print error messages in the background on the driver.

    The errors that are already listed under "ErrorKeys" are printed first.
    After that the function waits for keyspace notifications on "ErrorKeys"
    and prints the errors that were added since. It returns when the Redis
    connection is lost.

    Args:
        worker: The driver's worker, whose Redis client and lock are used.
    """
    # TODO(rkn): All error messages should have a "component" field indicating
    # which process the error came from (e.g., a worker or a plasma store).
    # Currently all error messages come from workers.

    helpful_message = """
  You can inspect errors by running

      ray.error_info()

  If this driver is hanging, start a new one with

      ray.init(redis_address="{}")
  """.format(worker.redis_address)

    def report(error_keys):
        """Print the errors for this driver; return the number of keys."""
        for error_key in error_keys:
            if not error_applies_to_driver(error_key, worker=worker):
                continue
            raw_message = worker.redis_client.hget(error_key, "message")
            print(raw_message.decode("ascii"))
            print(helpful_message)
        return len(error_keys)

    worker.error_message_pubsub_client = worker.redis_client.pubsub()
    # Exports that are published after the call to
    # error_message_pubsub_client.psubscribe and before the call to
    # error_message_pubsub_client.listen will still be processed in the loop.
    worker.error_message_pubsub_client.psubscribe("__keyspace@0__:ErrorKeys")
    # Number of keys of the "ErrorKeys" list that were looked at so far. It
    # is the offset from which the list is read after a notification.
    num_errors_received = 0

    # Get the exports that occurred before the call to psubscribe.
    with worker.lock:
        backlog = worker.redis_client.lrange("ErrorKeys", 0, -1)
        num_errors_received += report(backlog)

    try:
        for msg in worker.error_message_pubsub_client.listen():
            with worker.lock:
                new_keys = worker.redis_client.lrange(
                    "ErrorKeys", num_errors_received, -1)
                num_errors_received += report(new_keys)
    except redis.ConnectionError:
        # When Redis terminates the listen call will throw a ConnectionError,
        # which we catch here.
        pass
|
|
|
|
|
|
def fetch_and_register_remote_function(key, worker=global_worker):
    """Import a remote function.

    The function and its properties are read from the Redis hash stored under
    key. A placeholder that raises when called is registered first, so that
    an entry for the function exists even if unpickling fails. If unpickling
    succeeds, the placeholder is replaced by the real function and this
    worker is appended to the function table entry of the function.
    Otherwise the traceback is pushed to the driver.

    Args:
        key: The Redis key of the exported remote function.
        worker: The worker on which the function is registered.
    """
    (driver_id, function_id_str, function_name,
     serialized_function, num_return_vals, module, num_cpus,
     num_gpus, num_custom_resource, max_calls) = worker.redis_client.hmget(
        key, ["driver_id",
              "function_id",
              "name",
              "function",
              "num_return_vals",
              "module",
              "num_cpus",
              "num_gpus",
              "num_custom_resource",
              "max_calls"])
    function_id = ray.local_scheduler.ObjectID(function_id_str)
    function_name = function_name.decode("ascii")
    function_properties = FunctionProperties(
        num_return_vals=int(num_return_vals),
        num_cpus=int(num_cpus),
        num_gpus=int(num_gpus),
        num_custom_resource=int(num_custom_resource),
        max_calls=int(max_calls))
    module = module.decode("ascii")

    # This is a placeholder in case the function can't be unpickled. This will
    # be overwritten if the function is successfully registered.
    def f():
        raise Exception("This function was not imported properly.")
    remote_f_placeholder = remote(function_id=function_id)(lambda *xs: f())
    worker.functions[driver_id][function_id.id()] = (function_name,
                                                     remote_f_placeholder)
    worker.function_properties[driver_id][function_id.id()] = (
        function_properties)
    worker.num_task_executions[driver_id][function_id.id()] = 0

    try:
        # NOTE(review): unpickling executes arbitrary code, so whoever can
        # write to this Redis instance can run code on the worker.
        function = pickle.loads(serialized_function)
    except Exception:
        # Only Exception is caught, so that SystemExit and KeyboardInterrupt
        # are not reported as import failures. If an exception was thrown
        # when the remote function was imported, we record the traceback and
        # notify the scheduler of the failure.
        traceback_str = format_error_message(traceback.format_exc())
        # Log the error message.
        worker.push_error_to_driver(driver_id, "register_remote_function",
                                    traceback_str,
                                    data={"function_id": function_id.id(),
                                          "function_name": function_name})
    else:
        # TODO(rkn): Why is the below line necessary?
        function.__module__ = module
        worker.functions[driver_id][function_id.id()] = (
            function_name, remote(function_id=function_id)(function))
        # Add the function to the function table.
        worker.redis_client.rpush(b"FunctionTable:" + function_id.id(),
                                  worker.worker_id)
|
|
|
|
|
|
def fetch_and_execute_function_to_run(key, worker=global_worker):
    """Run an arbitrary function on the worker.

    The pickled function is read from the Redis hash stored under key and is
    called with a dictionary holding the worker and a counter. If unpickling
    or running the function fails, the traceback is pushed to the driver.

    Args:
        key: The Redis key of the exported function.
        worker: The worker on which the function is run.
    """
    driver_id, serialized_function = worker.redis_client.hmget(
        key, ["driver_id", "function"])
    # Get the number of workers on this node that have already started
    # executing this remote function, and increment that value. Subtract 1 so
    # the counter starts at 0.
    counter = worker.redis_client.hincrby(worker.node_ip_address, key, 1) - 1
    # Stays None if unpickling fails, so that the error handler below does
    # not have to inspect locals() to find out how far we got.
    function = None
    try:
        # Deserialize the function.
        # NOTE(review): unpickling executes arbitrary code, so whoever can
        # write to this Redis instance can run code on the worker.
        function = pickle.loads(serialized_function)
        # Run the function.
        function({"counter": counter, "worker": worker})
    except Exception:
        # Only Exception is caught, so that SystemExit and KeyboardInterrupt
        # are not reported as function failures. If an exception was thrown
        # when the function was run, we record the traceback and notify the
        # scheduler of the failure.
        traceback_str = traceback.format_exc()
        # Log the error message. The name is empty if the function could not
        # be unpickled or has no __name__.
        name = getattr(function, "__name__", "")
        worker.push_error_to_driver(driver_id, "function_to_run",
                                    traceback_str, data={"name": name})
|
|
|
|
|
|
def import_thread(worker, mode):
    """Import everything that is exported through the Redis "Exports" list.

    The keys that are already in the list are processed first. Afterwards
    the function waits for keyspace notifications on the list and processes
    the keys that were appended. A process in any mode other than
    WORKER_MODE only runs "FunctionsToRun" exports, whereas a worker also
    registers remote functions and actors. The function returns when the
    Redis connection is lost.

    Args:
        worker: The worker whose Redis client and lock are used and on which
            the imports are registered.
        mode: The mode of this process. Every mode other than WORKER_MODE is
            handled as a driver.
    """
    is_driver = mode != WORKER_MODE
    worker.import_pubsub_client = worker.redis_client.pubsub()
    # Exports that are published after the call to
    # import_pubsub_client.psubscribe and before the call to
    # import_pubsub_client.listen will still be processed in the loop.
    worker.import_pubsub_client.psubscribe("__keyspace@0__:Exports")
    # Keep track of the number of imports that we've imported.
    num_imported = 0

    # Get the exports that occurred before the call to psubscribe.
    with worker.lock:
        for key in worker.redis_client.lrange("Exports", 0, -1):
            num_imported += 1
            is_function_to_run = key.startswith(b"FunctionsToRun")

            if is_driver:
                # FunctionsToRun are the only things that the driver should
                # import.
                if is_function_to_run:
                    fetch_and_execute_function_to_run(key, worker=worker)
                continue

            if key.startswith(b"RemoteFunction"):
                fetch_and_register_remote_function(key, worker=worker)
            elif is_function_to_run:
                fetch_and_execute_function_to_run(key, worker=worker)
            elif key.startswith(b"ActorClass"):
                # If this worker is an actor that is supposed to construct
                # this class, fetch the actor and class information and
                # construct the class.
                class_id = key.split(b":", 1)[1]
                is_actor = worker.actor_id != NIL_ACTOR_ID
                if is_actor and worker.class_id == class_id:
                    worker.fetch_and_register_actor(key, worker)
            else:
                raise Exception("This code should be unreachable.")

    try:
        for msg in worker.import_pubsub_client.listen():
            with worker.lock:
                if msg["type"] == "psubscribe":
                    continue
                assert msg["data"] == b"rpush"
                num_imports = worker.redis_client.llen("Exports")
                assert num_imports >= num_imported
                while num_imported < num_imports:
                    export_index = num_imported
                    num_imported += 1
                    key = worker.redis_client.lindex("Exports", export_index)
                    is_function_to_run = key.startswith(b"FunctionsToRun")

                    if is_driver:
                        # FunctionsToRun are the only things that the driver
                        # should import.
                        if is_function_to_run:
                            with log_span("ray:import_function_to_run",
                                          worker=worker):
                                fetch_and_execute_function_to_run(
                                    key, worker=worker)
                        continue

                    if key.startswith(b"RemoteFunction"):
                        with log_span("ray:import_remote_function",
                                      worker=worker):
                            fetch_and_register_remote_function(
                                key, worker=worker)
                    elif is_function_to_run:
                        with log_span("ray:import_function_to_run",
                                      worker=worker):
                            fetch_and_execute_function_to_run(
                                key, worker=worker)
                    elif key.startswith(b"Actor"):
                        # NOTE(review): the loop over the backlog above
                        # matches b"ActorClass" keys, compares class IDs and
                        # calls worker.fetch_and_register_actor. This branch
                        # differs on all three points and may be stale.
                        # Confirm which of the two is intended.
                        # Only get the actor if the actor ID matches the
                        # actor ID of this worker.
                        actor_id, = worker.redis_client.hmget(key, "actor_id")
                        if worker.actor_id == actor_id:
                            worker.fetch_and_register["Actor"](key, worker)
                    else:
                        raise Exception("This code should be unreachable.")
    except redis.ConnectionError:
        # When Redis terminates the listen call will throw a ConnectionError,
        # which we catch here.
        pass
|
|
|
|
|
|
def connect(info, object_id_seed=None, mode=WORKER_MODE, worker=global_worker,
            actor_id=NIL_ACTOR_ID):
    """Connect this worker to the local scheduler, to Plasma, and to Redis.

    In PYTHON_MODE only the basic worker fields are initialized and the
    function returns before any client is created. In the other modes the
    worker creates a Redis client, registers itself in Redis (under a
    "Drivers:" key in SCRIPT_MODE/SILENT_MODE, under a "Workers:" key in
    WORKER_MODE), connects to the object store and the local scheduler, and
    starts the import thread. Drivers additionally add a driver task to the
    task table and export the functions that were cached before connecting.

    Args:
        info (dict): A dictionary with address of the Redis server and the
            sockets of the plasma store, plasma manager, and local scheduler.
            The keys read here are "node_ip_address", "redis_address",
            "store_socket_name", "manager_socket_name",
            "local_scheduler_socket_name" and, for drivers when no "webui"
            key exists in Redis yet, "webui_url".
        object_id_seed: A seed to use to make the generation of object IDs
            deterministic.
        mode: The mode of the worker. One of SCRIPT_MODE, WORKER_MODE,
            PYTHON_MODE, and SILENT_MODE.
        worker: The worker object to connect.
        actor_id: The ID of the actor running on this worker. If this worker is
            not an actor, then this is NIL_ACTOR_ID.

    Raises:
        AssertionError: If the worker is already connected, or if either of
            its caches of functions to run / remote functions is None.
        Exception: If the mode is none of the four modes listed above (raised
            at the Redis registration step).
    """
    check_main_thread()
    # Do some basic checking to make sure we didn't call ray.init twice.
    error_message = "Perhaps you called ray.init twice by accident?"
    assert not worker.connected, error_message
    assert worker.cached_functions_to_run is not None, error_message
    assert worker.cached_remote_functions is not None, error_message
    # Initialize some fields.
    worker.worker_id = random_string()
    worker.actor_id = actor_id
    worker.connected = True
    worker.set_mode(mode)
    # The worker.events field is used to aggregate logging information and
    # display it in the web UI. Note that Python lists are protected by the
    # GIL, which is important because we will append to this field from
    # multiple threads.
    worker.events = []
    # If running Ray in PYTHON_MODE, there is no need to call create_worker
    # or to start the worker service.
    if mode == PYTHON_MODE:
        return
    # Set the node IP address.
    worker.node_ip_address = info["node_ip_address"]
    worker.redis_address = info["redis_address"]

    # Create a Redis client. The address is expected in "host:port" form.
    redis_ip_address, redis_port = info["redis_address"].split(":")
    worker.redis_client = redis.StrictRedis(host=redis_ip_address,
                                            port=int(redis_port))
    worker.lock = threading.Lock()

    # Check the RedirectOutput key in Redis and based on its value redirect
    # worker output and error to their own files.
    if mode == WORKER_MODE:
        # This key is set in services.py when Redis is started.
        redirect_worker_output_val = worker.redis_client.get("RedirectOutput")
        if (redirect_worker_output_val is not None and
                int(redirect_worker_output_val) == 1):
            redirect_worker_output = 1
        else:
            redirect_worker_output = 0
        if redirect_worker_output:
            log_stdout_file, log_stderr_file = services.new_log_files("worker",
                                                                      True)
            # From here on, everything this process prints goes to the new
            # log files.
            sys.stdout = log_stdout_file
            sys.stderr = log_stderr_file
            services.record_log_files_in_redis(info["redis_address"],
                                               info["node_ip_address"],
                                               [log_stdout_file,
                                                log_stderr_file])

    # Create an object for interfacing with the global state.
    global_state._initialize_global_state(redis_ip_address, int(redis_port))

    # Register the worker with Redis.
    if mode in [SCRIPT_MODE, SILENT_MODE]:
        # The concept of a driver is the same as the concept of a "job".
        # Register the driver/job with Redis here.
        import __main__ as main
        driver_info = {
            "node_ip_address": worker.node_ip_address,
            "driver_id": worker.worker_id,
            "start_time": time.time(),
            "plasma_store_socket": info["store_socket_name"],
            "plasma_manager_socket": info["manager_socket_name"],
            "local_scheduler_socket": info["local_scheduler_socket_name"]}
        # __main__ has no __file__ attribute when there is no script file, in
        # which case the driver is recorded as interactive.
        driver_info["name"] = (main.__file__ if hasattr(main, "__file__")
                               else "INTERACTIVE MODE")
        worker.redis_client.hmset(b"Drivers:" + worker.worker_id, driver_info)
        # Only the first driver to get here stores the web UI URL.
        if not worker.redis_client.exists("webui"):
            worker.redis_client.hmset("webui", {"url": info["webui_url"]})
        is_worker = False
    elif mode == WORKER_MODE:
        # Register the worker with Redis.
        worker_dict = {
            "node_ip_address": worker.node_ip_address,
            "plasma_store_socket": info["store_socket_name"],
            "plasma_manager_socket": info["manager_socket_name"],
            "local_scheduler_socket": info["local_scheduler_socket_name"]}
        # redirect_worker_output is always bound here because it is assigned
        # in the WORKER_MODE branch above.
        if redirect_worker_output:
            worker_dict["stdout_file"] = os.path.abspath(log_stdout_file.name)
            worker_dict["stderr_file"] = os.path.abspath(log_stderr_file.name)
        worker.redis_client.hmset(b"Workers:" + worker.worker_id, worker_dict)
        is_worker = True
    else:
        raise Exception("This code should be unreachable.")

    # Create an object store client.
    worker.plasma_client = plasma.connect(info["store_socket_name"],
                                          info["manager_socket_name"],
                                          64)
    # Create the local scheduler client. Actors look up their GPU count in
    # the actor's Redis entry; non-actors request no GPUs.
    if worker.actor_id != NIL_ACTOR_ID:
        num_gpus = int(worker.redis_client.hget(b"Actor:" + actor_id,
                                                "num_gpus"))
    else:
        num_gpus = 0
    worker.local_scheduler_client = ray.local_scheduler.LocalSchedulerClient(
        info["local_scheduler_socket_name"], worker.worker_id, worker.actor_id,
        is_worker, num_gpus)

    # If this is a driver, set the current task ID, the task driver ID, and set
    # the task index to 0.
    if mode in [SCRIPT_MODE, SILENT_MODE]:
        # If the user provided an object_id_seed, then set the current task ID
        # deterministically based on that seed (without altering the state of
        # the user's random number generator). Otherwise, set the current task
        # ID randomly to avoid object ID collisions.
        numpy_state = np.random.get_state()
        if object_id_seed is not None:
            np.random.seed(object_id_seed)
        else:
            # Try to use true randomness.
            np.random.seed(None)
        worker.current_task_id = ray.local_scheduler.ObjectID(
            np.random.bytes(20))
        # When tasks are executed on remote workers in the context of multiple
        # drivers, the task driver ID is used to keep track of which driver is
        # responsible for the task so that error messages will be propagated to
        # the correct driver.
        worker.task_driver_id = ray.local_scheduler.ObjectID(worker.worker_id)
        # Reset the state of the numpy random number generator.
        np.random.set_state(numpy_state)
        # Set other fields needed for computing task IDs.
        worker.task_index = 0
        worker.put_index = 0

        # Create an entry for the driver task in the task table. This task is
        # added immediately with status RUNNING. This allows us to push errors
        # related to this driver task back to the driver. For example, if the
        # driver creates an object that is later evicted, we should notify the
        # user that we're unable to reconstruct the object, since we cannot
        # rerun the driver.
        driver_task = ray.local_scheduler.Task(
            worker.task_driver_id,
            ray.local_scheduler.ObjectID(NIL_FUNCTION_ID),
            [],
            0,
            worker.current_task_id,
            worker.task_index,
            ray.local_scheduler.ObjectID(NIL_ACTOR_ID),
            worker.actor_counters[actor_id],
            [0, 0, 0])
        global_state._execute_command(
            driver_task.task_id(),
            "RAY.TASK_TABLE_ADD",
            driver_task.task_id().id(),
            TASK_STATUS_RUNNING,
            NIL_LOCAL_SCHEDULER_ID,
            ray.local_scheduler.task_to_string(driver_task))
        # Set the driver's current task ID to the task ID assigned to the
        # driver task.
        worker.current_task_id = driver_task.task_id()

    # If this is an actor, get the ID of the corresponding class for the actor.
    if worker.actor_id != NIL_ACTOR_ID:
        actor_key = b"Actor:" + worker.actor_id
        class_id = worker.redis_client.hget(actor_key, "class_id")
        worker.class_id = class_id

    # Initialize the serialization library. This registers some classes, and so
    # it must be run before we export all of the cached remote functions.
    _initialize_serialization()

    # Start a thread to import exports from the driver or from other workers.
    # Note that the driver also has an import thread, which is used only to
    # import custom class definitions from calls to _register_class that happen
    # under the hood on workers.
    t = threading.Thread(target=import_thread, args=(worker, mode))
    # Making the thread a daemon causes it to exit when the main thread exits.
    t.daemon = True
    t.start()

    # If this is a driver running in SCRIPT_MODE, start a thread to print error
    # messages asynchronously in the background. Ideally the scheduler would
    # push messages to the driver's worker service, but we ran into bugs when
    # trying to properly shutdown the driver's worker service, so we are
    # temporarily using this implementation which constantly queries the
    # scheduler for new error messages.
    if mode == SCRIPT_MODE:
        t = threading.Thread(target=print_error_messages, args=(worker,))
        # Making the thread a daemon causes it to exit when the main thread
        # exits.
        t.daemon = True
        t.start()

    if mode in [SCRIPT_MODE, SILENT_MODE]:
        # Add the directory containing the script that is running to the Python
        # paths of the workers. Also add the current directory. Note that this
        # assumes that the directory structures on the machines in the clusters
        # are the same.
        script_directory = os.path.abspath(os.path.dirname(sys.argv[0]))
        current_directory = os.path.abspath(os.path.curdir)
        worker.run_function_on_all_workers(
            lambda worker_info: sys.path.insert(1, script_directory))
        worker.run_function_on_all_workers(
            lambda worker_info: sys.path.insert(1, current_directory))
        # TODO(rkn): Here we first export functions to run, then remote
        # functions. The order matters. For example, one of the functions to
        # run may set the Python path, which is needed to import a module used
        # to define a remote function. We may want to change the order to
        # simply be the order in which the exports were defined on the driver.
        # In addition, we will need to retain the ability to decide what the
        # first few exports are (mostly to set the Python path). Additionally,
        # note that the first exports to be defined on the driver will be the
        # ones defined in separate modules that are imported by the driver.
        # Export cached functions_to_run.
        for function in worker.cached_functions_to_run:
            worker.run_function_on_all_workers(function)
        # Export cached remote functions to the workers.
        # NOTE: the loop variable below shadows the `info` parameter; the
        # parameter is not read again after this point.
        for info in worker.cached_remote_functions:
            (function_id, func_name, func,
             func_invoker, function_properties) = info
            export_remote_function(function_id, func_name, func, func_invoker,
                                   function_properties, worker)
    # With the caches set to None, the assertions at the top of this function
    # fail if connect is called again before the caches are reset to lists.
    worker.cached_functions_to_run = None
    worker.cached_remote_functions = None
|
|
|
|
|
|
def disconnect(worker=global_worker):
    """Mark the worker as disconnected and reset its cached state.

    The caches of functions to run and of remote functions are replaced by
    empty lists so that, if more remote functions are defined and connect is
    called again, those functions will be exported. This is mostly relevant
    for the tests. The serialization context is replaced by a fresh one.

    Args:
        worker: The worker to reset.
    """
    worker.connected = False
    # Empty lists (rather than None) allow a later call to connect.
    worker.cached_functions_to_run = list()
    worker.cached_remote_functions = list()
    fresh_context = pyarrow.SerializationContext()
    worker.serialization_context = fresh_context
|
|
|
|
|
|
def register_class(cls, pickle=False, worker=global_worker):
    """Deprecated entry point; calling it always raises.

    All arguments are ignored.

    Raises:
        Exception: Always, with a message stating that the function is
            deprecated.
    """
    deprecation_message = ("The function ray.register_class is deprecated. "
                           "It should be safe to remove any calls to this "
                           "function.")
    raise Exception(deprecation_message)
|
|
|
|
|
|
def _register_class(cls, pickle=False, worker=global_worker):
    """Enable serialization and deserialization for a particular class.

    A random class ID is generated and the class is registered with a
    serialization context under that ID. When pickle is False the
    registration function is run on every worker; when pickle is True it is
    run only against the given worker.

    Args:
        cls (type): The class that ray should serialize.
        pickle (bool): If False then objects of this class will be serialized
            by turning their __dict__ fields into a dictionary. If True, then
            objects of this class will be serialized using pickle.
        worker: The worker used to run the registration.

    Raises:
        Exception: An exception is raised if pickle=False and the class cannot
            be efficiently serialized by Ray.
    """
    class_id = random_string()

    def register_class_for_serialization(worker_info):
        # Register cls on the worker found in the worker_info dictionary.
        context = worker_info["worker"].serialization_context
        context.register_type(cls, class_id, pickle=pickle)

    if pickle:
        # Since we are pickling objects of this class, we don't actually need
        # to ship the class definition.
        register_class_for_serialization({"worker": worker})
        return

    # Raise an exception if cls cannot be serialized efficiently by Ray.
    serialization.check_serializable(cls)
    worker.run_function_on_all_workers(register_class_for_serialization)
|
|
|
|
|
|
class RayLogSpan(object):
    """Context manager that logs the start and the end of a span of events.

    Attributes:
        event_type (str): The type of the event being logged.
        contents: Additional information to log with the start of the span.
        worker: The worker that the events are logged to.
    """

    def __init__(self, event_type, contents=None, worker=global_worker):
        """Store the event type, the extra contents, and the worker."""
        self.worker = worker
        self.contents = contents
        self.event_type = event_type

    def __enter__(self):
        """Log a LOG_SPAN_START event carrying the span's contents."""
        log(self.event_type, LOG_SPAN_START, contents=self.contents,
            worker=self.worker)

    def __exit__(self, type, value, tb):
        """Log a LOG_SPAN_END event, including any exception that occurred.

        Nothing is returned, so an exception raised inside the with block
        continues to propagate.
        """
        if type is None:
            # The block finished normally, so there is nothing extra to log.
            failure_details = None
        else:
            failure_details = {"type": str(type),
                               "value": value,
                               "traceback": traceback.format_exc()}
        log(self.event_type, LOG_SPAN_END, contents=failure_details,
            worker=self.worker)
|
|
|
|
|
|
def log_span(event_type, contents=None, worker=global_worker):
    """Build a RayLogSpan that can be used in a with statement.

    Args:
        event_type (str): The type of the event.
        contents: More general data to store with the event.
        worker: The worker that the span is logged to.

    Returns:
        A RayLogSpan constructed from the given arguments.
    """
    span = RayLogSpan(event_type, contents=contents, worker=worker)
    return span
|
|
|
|
|
|
def log_event(event_type, contents=None, worker=global_worker):
    """Log an event of kind LOG_POINT.

    Args:
        event_type (str): The type of the event.
        contents: More general data to store with the event.
        worker: The worker that the event is logged to.
    """
    log(event_type, LOG_POINT, contents=contents, worker=worker)
|
|
|
|
|
|
def log(event_type, kind, contents=None, worker=global_worker):
    """Append an event to the worker's local buffer of events.

    The buffer can be flushed and written to the global state store by
    calling flush_log().

    Args:
        event_type (str): The type of the event.
        kind (int): Either LOG_POINT, LOG_SPAN_START, or LOG_SPAN_END. This is
            LOG_POINT if the event being logged happens at a single point in
            time. It is LOG_SPAN_START if we are starting to log a span of
            time, and it is LOG_SPAN_END if we are finishing logging a span of
            time.
        contents: More general data to store with the event. Must be a
            dictionary or None; None is treated as an empty dictionary.
        worker: The worker whose events list receives the entry.
    """
    # TODO(rkn): This code currently takes around half a microsecond. Since we
    # call it tens of times per task, this adds up. We will need to redo the
    # logging code, perhaps in C.
    if contents is None:
        contents = {}
    assert isinstance(contents, dict)
    # Make sure all of the keys and values in the dictionary are strings.
    stringified = dict((str(key), str(val))
                       for key, val in contents.items())
    entry = (time.time(), event_type, kind, stringified)
    worker.events.append(entry)
|
|
|
|
|
|
def flush_log(worker=global_worker):
    """Send the logged worker events to the global state store.

    The buffered events are serialized to JSON and handed to the local
    scheduler client under a key derived from the worker ID, after which the
    worker's buffer is replaced by an empty list.

    Args:
        worker: The worker whose events are flushed.
    """
    log_key = b"event_log:" + worker.worker_id
    serialized_events = json.dumps(worker.events)
    worker.local_scheduler_client.log_event(
        log_key, serialized_events, time.time())
    worker.events = []
|
|
|
|
|
|
def get(object_ids, worker=global_worker):
    """Get a remote object or a list of remote objects from the object store.

    This method blocks until the object corresponding to the object ID is
    available in the local object store. If this object is not in the local
    object store, it will be shipped from an object store that has it (once the
    object has been created). If object_ids is a list, then the objects
    corresponding to each object in the list will be returned.

    Args:
        object_ids: Object ID of the object to get or a list of object IDs to
            get.
        worker: The worker used to fetch the objects.

    Returns:
        A Python object or a list of Python objects.

    Raises:
        RayGetError: If any fetched value is a RayTaskError.
    """
    check_connected(worker)
    with log_span("ray:get", worker=worker):
        check_main_thread()

        if worker.mode == PYTHON_MODE:
            # In PYTHON_MODE, ray.get is the identity operation (the input will
            # actually be a value not an objectid).
            return object_ids

        if not isinstance(object_ids, list):
            # A single object ID: fetch it as a one-element list and unwrap.
            fetched = worker.get_object([object_ids])[0]
            if isinstance(fetched, RayTaskError):
                # If the result is a RayTaskError, then the task that created
                # this object failed, and we should propagate the error message
                # here.
                raise RayGetError(object_ids, fetched)
            return fetched

        fetched_values = worker.get_object(object_ids)
        for index, fetched in enumerate(fetched_values):
            if isinstance(fetched, RayTaskError):
                raise RayGetError(object_ids[index], fetched)
        return fetched_values
|
|
|
|
|
|
def put(value, worker=global_worker):
    """Store an object in the object store.

    Args:
        value: The Python object to be stored.
        worker: The worker used to store the object.

    Returns:
        The object ID assigned to this value. In PYTHON_MODE the value itself
        is returned instead.
    """
    check_connected(worker)
    with log_span("ray:put", worker=worker):
        check_main_thread()

        if worker.mode == PYTHON_MODE:
            # In PYTHON_MODE, ray.put is the identity operation.
            return value
        # The ID is computed from the current task ID and the number of puts
        # performed so far.
        new_object_id = worker.local_scheduler_client.compute_put_id(
            worker.current_task_id, worker.put_index)
        worker.put_object(new_object_id, value)
        worker.put_index += 1
        return new_object_id
|
|
|
|
|
|
def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
    """Return a list of IDs that are ready and a list of IDs that are not.

    If timeout is set, the function returns either when the requested number of
    IDs are ready or when the timeout is reached, whichever occurs first. If it
    is not set, the function simply waits until that number of objects is ready
    and returns that exact number of objectids.

    This method returns two lists. The first list consists of object IDs that
    correspond to objects that are stored in the object store. The second list
    corresponds to the rest of the object IDs (which may or may not be ready).

    Args:
        object_ids (List[ObjectID]): List of object IDs for objects that may or
            may not be ready. Note that these IDs must be unique.
        num_returns (int): The number of object IDs that should be returned.
        timeout (int): The maximum amount of time in milliseconds to wait
            before returning. When None, 2 ** 30 is passed to the plasma
            client.
        worker: The worker whose plasma client performs the wait.

    Returns:
        A list of object IDs that are ready and a list of the remaining object
        IDs.
    """
    check_connected(worker)
    with log_span("ray:wait", worker=worker):
        check_main_thread()
        # The plasma client works with its own ObjectID type.
        plasma_ids = [plasma.ObjectID(ray_id.id()) for ray_id in object_ids]
        if timeout is None:
            timeout = 2 ** 30
        ready, remaining = worker.plasma_client.wait(plasma_ids, timeout,
                                                     num_returns)
        # Convert the results back into local scheduler ObjectIDs.
        ready_ids = [ray.local_scheduler.ObjectID(plasma_id.binary())
                     for plasma_id in ready]
        remaining_ids = [ray.local_scheduler.ObjectID(plasma_id.binary())
                         for plasma_id in remaining]
        return ready_ids, remaining_ids
|
|
|
|
|
|
def format_error_message(exception_message, task_exception=False):
    """Improve the formatting of an exception thrown by a remote function.

    Args:
        exception_message (str): A message generated by traceback.format_exc().
        task_exception (bool): If True, the second through fifth lines of the
            message are dropped; the first line and everything from the sixth
            line onward are kept.

    Returns:
        A string of the formatted exception message.
    """
    message_lines = exception_message.split("\n")
    if task_exception:
        # For errors that occur inside of tasks, remove lines 1, 2, 3, and 4
        # (counting from 0), which are always the same, they just contain
        # information about the main loop.
        del message_lines[1:5]
    return "\n".join(message_lines)
|
|
|
|
|
|
def _submit_task(function_id, args, worker=global_worker):
    """Submit a task through worker.submit_task and return its result.

    We use this wrapper so that in the remote decorator, we can call
    _submit_task instead of worker.submit_task. The difference is that when we
    attempt to serialize remote functions, we don't attempt to serialize the
    worker object, which cannot be serialized.

    Args:
        function_id: The ID of the function to run.
        args: The arguments to pass to the function.
        worker: The worker that submits the task.

    Returns:
        Whatever worker.submit_task returns.
    """
    submitted = worker.submit_task(function_id, args)
    return submitted
|
|
|
|
|
|
def _mode(worker=global_worker):
    """Return worker.mode.

    We use this wrapper so that in the remote decorator, we can call _mode()
    instead of worker.mode. The difference is that when we attempt to serialize
    remote functions, we don't attempt to serialize the worker object, which
    cannot be serialized.

    Args:
        worker: The worker whose mode is read.

    Returns:
        The mode attribute of the worker.
    """
    current_mode = worker.mode
    return current_mode
|
|
|
|
|
|
def export_remote_function(function_id, func_name, func, func_invoker,
                           function_properties, worker=global_worker):
    """Pickle a remote function and publish it through Redis.

    The function properties are recorded on the worker, the function is
    pickled, its definition is written to a Redis hash keyed by the driver ID
    and the function ID, and that key is appended to the "Exports" list.

    Args:
        function_id: The ID of the function being exported.
        func_name: The name stored with the export.
        func: The function to pickle.
        func_invoker: The object that func's own name is temporarily bound to
            in func's globals while func is being pickled.
        function_properties: The properties stored with the export
            (num_return_vals, num_cpus, num_gpus, num_custom_resource, and
            max_calls are read).
        worker: The worker performing the export.

    Raises:
        Exception: If the worker is not in SCRIPT_MODE or SILENT_MODE.
    """
    check_main_thread()
    if _mode(worker) not in (SCRIPT_MODE, SILENT_MODE):
        raise Exception("export_remote_function can only be called on a "
                        "driver.")

    properties_for_driver = worker.function_properties[
        worker.task_driver_id.id()]
    properties_for_driver[function_id.id()] = function_properties
    task_driver_id = worker.task_driver_id
    export_key = (b"RemoteFunction:" + task_driver_id.id() + b":" +
                  function_id.id())

    # Work around limitations of Python pickling: remember whether the
    # function's own name is currently a global, and what it is bound to.
    namespace = func.__globals__
    own_name = func.__name__
    name_was_global = own_name in namespace
    previous_binding = namespace.get(own_name)
    # Allow the function to reference itself as a global variable
    namespace[own_name] = func_invoker
    try:
        pickled_func = pickle.dumps(func)
    finally:
        # Undo our changes
        if name_was_global:
            namespace[own_name] = previous_binding
        else:
            del namespace[own_name]

    export_fields = {
        "driver_id": worker.task_driver_id.id(),
        "function_id": function_id.id(),
        "name": func_name,
        "module": func.__module__,
        "function": pickled_func,
        "num_return_vals": function_properties.num_return_vals,
        "num_cpus": function_properties.num_cpus,
        "num_gpus": function_properties.num_gpus,
        "num_custom_resource": function_properties.num_custom_resource,
        "max_calls": function_properties.max_calls}
    worker.redis_client.hmset(export_key, export_fields)
    worker.redis_client.rpush("Exports", export_key)
|
|
|
|
|
|
def in_ipython():
    """Return true if we are in an IPython interpreter and false otherwise.

    The check is whether the name __IPYTHON__ can be resolved.
    """
    try:
        __IPYTHON__
    except NameError:
        return False
    else:
        return True
|
|
|
|
|
|
def compute_function_id(func_name, func):
    """Compute a function ID for a function.

    The ID is the SHA-1 digest of the function name and, when the source is
    expected to be available, of the function's source code.

    Args:
        func_name: The name of the function (this includes the module name plus
            the function name).
        func: The actual function.

    Returns:
        This returns the function ID.
    """
    function_id_hash = hashlib.sha1()
    # Include the function name in the hash. UTF-8 is used so that names and
    # source code containing non-ASCII characters can be hashed; for pure
    # ASCII text the encoded bytes (and therefore the ID) are the same as
    # with the "ascii" codec.
    function_id_hash.update(func_name.encode("utf-8"))
    # If we are running a script or are in IPython, include the source code in
    # the hash. If we are in a regular Python interpreter we skip this part
    # because the source code is not accessible.
    import __main__ as main
    if hasattr(main, "__file__") or in_ipython():
        function_id_hash.update(inspect.getsource(func).encode("utf-8"))
    # Compute the function ID.
    function_id = function_id_hash.digest()
    assert len(function_id) == 20
    function_id = FunctionID(function_id)

    return function_id
|
|
|
|
|
|
def remote(*args, **kwargs):
    """This decorator is used to define remote functions and to define actors.

    It can be applied bare ('@ray.remote') to a function or a class, or called
    with keyword arguments ('@ray.remote(num_return_vals=2)') to obtain a
    decorator.

    Args:
        num_return_vals (int): The number of object IDs that a call to this
            function should return.
        num_cpus (int): The number of CPUs needed to execute this function.
        num_gpus (int): The number of GPUs needed to execute this function.
        num_custom_resource (int): The quantity of a user-defined custom
            resource that is needed to execute this function. This flag is
            experimental and is subject to changes in the future.
        max_calls (int): The maximum number of tasks of this kind that can be
            run on a worker before the worker needs to be restarted.
        checkpoint_interval (int): The number of tasks to run between
            checkpoints of the actor state.

    Returns:
        When applied bare, the result of decorating the function or class.
        Otherwise a decorator to be applied to a function or class.

    Raises:
        AssertionError: If positional and keyword arguments are combined, or
            if an unknown keyword argument is passed.
        Exception: If the decorator is applied to something that is neither a
            function nor a class.
    """
    worker = global_worker

    # Build the decorator for one set of resource requirements. func_id, when
    # given, is used as the function ID instead of computing one.
    def make_remote_decorator(num_return_vals, num_cpus, num_gpus,
                              num_custom_resource, max_calls,
                              checkpoint_interval, func_id=None):
        # Dispatch on the decorated object: functions become remote
        # functions, classes are handed to worker.make_actor.
        def remote_decorator(func_or_class):
            if inspect.isfunction(func_or_class):
                function_properties = FunctionProperties(
                    num_return_vals=num_return_vals,
                    num_cpus=num_cpus,
                    num_gpus=num_gpus,
                    num_custom_resource=num_custom_resource,
                    max_calls=max_calls)
                return remote_function_decorator(func_or_class,
                                                 function_properties)
            if inspect.isclass(func_or_class):
                return worker.make_actor(func_or_class, num_cpus, num_gpus,
                                         checkpoint_interval)
            raise Exception("The @ray.remote decorator must be applied to "
                            "either a function or to a class.")

        # Wrap func in an invoker object and export it (or cache it for export
        # when the worker has no mode yet).
        def remote_function_decorator(func, function_properties):
            func_name = "{}.{}".format(func.__module__, func.__name__)
            if func_id is None:
                function_id = compute_function_id(func_name, func)
            else:
                function_id = func_id

            def func_call(*args, **kwargs):
                """This runs immediately when a remote function is called."""
                check_connected()
                check_main_thread()
                # function_signature is assigned further down in the enclosing
                # function, before func_call can be invoked.
                args = signature.extend_args(function_signature, args, kwargs)

                if _mode() == PYTHON_MODE:
                    # In PYTHON_MODE, remote calls simply execute the function.
                    # We copy the arguments to prevent the function call from
                    # mutating them and to match the usual behavior of
                    # immutable remote objects.
                    result = func(*copy.deepcopy(args))
                    return result
                objectids = _submit_task(function_id, args)
                # A single object ID is returned unwrapped; several are
                # returned as given. With no object IDs, neither branch runs
                # and the call returns None.
                if len(objectids) == 1:
                    return objectids[0]
                elif len(objectids) > 1:
                    return objectids

            def func_executor(arguments):
                """This gets run when the remote function is executed."""
                result = func(*arguments)
                return result

            def func_invoker(*args, **kwargs):
                """This is used to invoke the function."""
                raise Exception("Remote functions cannot be called directly. "
                                "Instead of running '{}()', try '{}.remote()'."
                                .format(func_name, func_name))
            func_invoker.remote = func_call
            func_invoker.executor = func_executor
            func_invoker.is_remote = True
            # This recomputes the same value that func_name was given at the
            # top of this function.
            func_name = "{}.{}".format(func.__module__, func.__name__)
            func_invoker.func_name = func_name
            # Copy the docstring of the wrapped function onto the invoker,
            # using the attribute name that matches the Python major version.
            if sys.version_info >= (3, 0):
                func_invoker.__doc__ = func.__doc__
            else:
                func_invoker.func_doc = func.func_doc

            signature.check_signature_supported(func)
            function_signature = signature.extract_signature(func)

            # Everything ready - export the function
            if worker.mode in [SCRIPT_MODE, SILENT_MODE]:
                export_remote_function(function_id, func_name, func,
                                       func_invoker, function_properties)
            elif worker.mode is None:
                # The worker has no mode yet, so remember the function in the
                # worker's cache of remote functions.
                worker.cached_remote_functions.append((function_id, func_name,
                                                       func, func_invoker,
                                                       function_properties))
            return func_invoker

        return remote_decorator

    # Read the supported keyword arguments, falling back to their defaults.
    num_return_vals = (kwargs["num_return_vals"] if "num_return_vals"
                       in kwargs else 1)
    num_cpus = kwargs["num_cpus"] if "num_cpus" in kwargs else 1
    num_gpus = kwargs["num_gpus"] if "num_gpus" in kwargs else 0
    num_custom_resource = (kwargs["num_custom_resource"]
                           if "num_custom_resource" in kwargs else 0)
    max_calls = kwargs["max_calls"] if "max_calls" in kwargs else 0
    checkpoint_interval = (kwargs["checkpoint_interval"]
                           if "checkpoint_interval" in kwargs else -1)

    # In WORKER_MODE a "function_id" keyword argument supplies the function ID
    # directly, so none is computed from the function.
    if _mode() == WORKER_MODE:
        if "function_id" in kwargs:
            function_id = kwargs["function_id"]
            return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
                                         num_custom_resource, max_calls,
                                         checkpoint_interval, function_id)

    if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
        # This is the case where the decorator is just @ray.remote.
        return make_remote_decorator(
            num_return_vals, num_cpus,
            num_gpus, num_custom_resource,
            max_calls, checkpoint_interval)(args[0])
    else:
        # This is the case where the decorator is something like
        # @ray.remote(num_return_vals=2).
        error_string = ("The @ray.remote decorator must be applied either "
                        "with no arguments and no parentheses, for example "
                        "'@ray.remote', or it must be applied using some of "
                        "the arguments 'num_return_vals', 'num_cpus', "
                        "'num_gpus', num_custom_resource, or 'max_calls', "
                        "like '@ray.remote(num_return_vals=2)'.")
        # There must be no positional arguments and at least one recognized
        # keyword argument.
        assert (len(args) == 0 and
                ("num_return_vals" in kwargs or
                 "num_cpus" in kwargs or
                 "num_gpus" in kwargs or
                 "num_custom_resource" in kwargs or
                 "max_calls" in kwargs or
                 "checkpoint_interval" in kwargs)), error_string
        # Every keyword argument must be one of the recognized names.
        for key in kwargs:
            assert key in ["num_return_vals", "num_cpus",
                           "num_gpus", "num_custom_resource", "max_calls",
                           "checkpoint_interval"], error_string
        assert "function_id" not in kwargs
        return make_remote_decorator(num_return_vals, num_cpus, num_gpus,
                                     num_custom_resource, max_calls,
                                     checkpoint_interval)
|